diff --git a/controlnet_aux/__init__.py b/controlnet_aux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67c91a7cab7bdb742c5f8fdbd434a56cfd049f11 --- /dev/null +++ b/controlnet_aux/__init__.py @@ -0,0 +1,20 @@ +__version__ = "0.0.9" + +from .anyline import AnylineDetector +from .canny import CannyDetector +from .dwpose import DWposeDetector +from .hed import HEDdetector +from .leres import LeresDetector +from .lineart import LineartDetector +from .lineart_anime import LineartAnimeDetector +from .lineart_standard import LineartStandardDetector +from .mediapipe_face import MediapipeFaceDetector +from .midas import MidasDetector +from .mlsd import MLSDdetector +from .normalbae import NormalBaeDetector +from .open_pose import OpenposeDetector +from .pidi import PidiNetDetector +from .segment_anything import SamDetector +from .shuffle import ContentShuffleDetector +from .teed import TEEDdetector +from .zoe import ZoeDetector diff --git a/controlnet_aux/anyline/__init__.py b/controlnet_aux/anyline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f75b1b372c75e2c3391af525107660c96315908 --- /dev/null +++ b/controlnet_aux/anyline/__init__.py @@ -0,0 +1,118 @@ +# code based in https://github.com/TheMistoAI/ComfyUI-Anyline/blob/main/anyline.py +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image +from skimage import morphology + +from ..teed.ted import TED +from ..util import HWC3, resize_image, safe_step + + +class AnylineDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None): + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download( + pretrained_model_or_path, filename, subfolder=subfolder + ) + + model = TED() + model.load_state_dict(torch.load(model_path, map_location="cpu")) + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__( + self, + input_image, + detect_resolution=1280, + guassian_sigma=2.0, + intensity_threshold=3, + output_type="pil", + ): + device = next(iter(self.model.parameters())).device + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + original_height, original_width, _ = input_image.shape + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + height, width, _ = input_image.shape + with torch.no_grad(): + image_teed = torch.from_numpy(input_image.copy()).float().to(device) + image_teed = rearrange(image_teed, "h w c -> 1 c h w") + edges = self.model(image_teed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [ + cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) + for e in edges + ] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = safe_step(edge, 2) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + mteed_result = edge + mteed_result = HWC3(mteed_result) + + x = input_image.astype(np.float32) + g = cv2.GaussianBlur(x, (0, 0), guassian_sigma) + intensity = np.min(g - x, axis=2).clip(0, 255) + intensity /= max(16, np.median(intensity[intensity > intensity_threshold])) + intensity *= 127 + lineart_result = intensity.clip(0, 255).astype(np.uint8) + + lineart_result = HWC3(lineart_result) + + lineart_result = self.get_intensity_mask( + lineart_result, lower_bound=0, upper_bound=255 + ) + + cleaned = morphology.remove_small_objects( + lineart_result.astype(bool), min_size=36, connectivity=1 + ) + lineart_result = lineart_result * cleaned + final_result = self.combine_layers(mteed_result, lineart_result) + + final_result = cv2.resize( + final_result, + (original_width, original_height), + interpolation=cv2.INTER_LINEAR, + ) + + if output_type == "pil": + final_result = Image.fromarray(final_result) + + return final_result + + def get_intensity_mask(self, image_array, lower_bound, upper_bound): + mask = image_array[:, :, 0] + mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0) + mask = np.expand_dims(mask, 2).repeat(3, axis=2) + return mask + + def combine_layers(self, base_layer, top_layer): + mask = top_layer.astype(bool) + temp = 1 - (1 - top_layer) * (1 - base_layer) + result = base_layer * (~mask) + temp * mask + return result diff --git a/controlnet_aux/canny/__init__.py b/controlnet_aux/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aca9ae3a34bce509bf34e3013bae3089ef69fbbe --- /dev/null +++ b/controlnet_aux/canny/__init__.py @@ -0,0 +1,36 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from ..util import HWC3, resize_image + +class CannyDetector: + def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs): + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + detected_map = cv2.Canny(input_image, low_threshold, high_threshold) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/dwpose/__init__.py b/controlnet_aux/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34e010fe7d02daa24f4f8e2532fe80c25e5a9959 --- /dev/null +++ b/controlnet_aux/dwpose/__init__.py @@ -0,0 +1,91 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import cv2 +import torch +import numpy as np +from PIL import Image + +from ..util import HWC3, resize_image +from . import util + + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + canvas = util.draw_bodypose(canvas, candidate, subset) + canvas = util.draw_handpose(canvas, hands) + canvas = util.draw_facepose(canvas, faces) + + return canvas + +class DWposeDetector: + def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"): + from .wholebody import Wholebody + + self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) + + def to(self, device): + self.pose_estimation.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + + input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + H, W, C = input_image.shape + + with torch.no_grad(): + candidate, subset = self.pose_estimation(input_image) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + detected_map = draw_pose(pose, H, W) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/dwpose/dwpose_config/__init__.py b/controlnet_aux/dwpose/dwpose_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py b/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py new file mode 100644 index 0000000000000000000000000000000000000000..d45abe64b04716e571610853158cb448455e2e7d --- /dev/null +++ b/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py @@ -0,0 +1,257 @@ +# runtime +max_epochs = 270 +stage2_num_epochs = 30 +base_lr = 4e-3 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=21) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# codec settings +codec = dict( + type='SimCCLabel', + input_size=(288, 384), + sigma=(6., 6.93), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1., + widen_factor=1., + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa + )), + head=dict( + type='RTMCCHead', + in_channels=1024, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(9, 12), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = '/data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/controlnet_aux/dwpose/util.py b/controlnet_aux/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f3ca644591351472f655b9274b882b48eb53ee --- /dev/null +++ b/controlnet_aux/dwpose/util.py @@ -0,0 +1,303 @@ +import math +import numpy as np +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + import matplotlib + + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + # (person_number*2, 21, 2) + for i in range(len(all_hand_peaks)): + peaks = all_hand_peaks[i] + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for _, keyponit in enumerate(peaks): + x, y = keyponit + + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/controlnet_aux/dwpose/wholebody.py b/controlnet_aux/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..0e92c5f913eb53c5eab685de93ce8d712eae1eec --- /dev/null +++ b/controlnet_aux/dwpose/wholebody.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import numpy as np +import warnings + +try: + import mmcv +except ImportError: + warnings.warn( + "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'" + ) + +try: + from mmpose.apis import inference_topdown + from mmpose.apis import init_model as init_pose_estimator + from mmpose.evaluation.functional import nms + from mmpose.utils import adapt_mmdet_pipeline + from mmpose.structures import merge_data_samples +except ImportError: + warnings.warn( + "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'" + ) + +try: + from mmdet.apis import inference_detector, init_detector +except ImportError: + warnings.warn( + "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'" + ) + + +class Wholebody: + def __init__(self, + det_config=None, det_ckpt=None, + pose_config=None, pose_ckpt=None, + device="cpu"): + + if det_config is None: + det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py") + + if pose_config is None: + pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py") + + if det_ckpt is None: + det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' + + if pose_ckpt is None: + pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth" + + # build detector + self.detector = init_detector(det_config, det_ckpt, device=device) + self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg) + + # build pose estimator + self.pose_estimator = init_pose_estimator( + pose_config, + pose_ckpt, + device=device) + + def to(self, device): + self.detector.to(device) + self.pose_estimator.to(device) + return self + + def __call__(self, oriImg): + # predict bbox + det_result = inference_detector(self.detector, oriImg) + pred_instance = det_result.pred_instances.cpu().numpy() + bboxes = np.concatenate( + (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) + bboxes = bboxes[np.logical_and(pred_instance.labels == 0, + pred_instance.scores > 0.5)] + + # set NMS threshold + bboxes = bboxes[nms(bboxes, 0.7), :4] + + # predict keypoints + if len(bboxes) == 0: + pose_results = inference_topdown(self.pose_estimator, oriImg) + else: + pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes) + preds = merge_data_samples(pose_results) + preds = preds.pred_instances + + # preds = pose_results[0].pred_instances + keypoints = preds.get('transformed_keypoints', + preds.keypoints) + if 'keypoint_scores' in preds: + scores = preds.keypoint_scores + else: + scores = np.ones(keypoints.shape[:-1]) + + if 'keypoints_visible' in preds: + visible = preds.keypoints_visible + else: + visible = np.ones(keypoints.shape[:-1]) + keypoints_info = np.concatenate( + (keypoints, scores[..., None], visible[..., None]), + axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores, visible = keypoints_info[ + ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] + + return keypoints, scores diff --git a/controlnet_aux/dwpose/yolox_config/__init__.py b/controlnet_aux/dwpose/yolox_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py b/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4cb5a4bbc1953f3d7b6ea84d130f8312083914 --- /dev/null +++ b/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py @@ -0,0 +1,245 @@ +img_scale = (640, 640) # width, height + +# model settings +model = dict( + type='YOLOX', + data_preprocessor=dict( + type='DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + backbone=dict( + type='CSPDarknet', + deepen_factor=1.0, + widen_factor=1.0, + out_indices=(2, 3, 4), + use_depthwise=False, + spp_kernal_sizes=(5, 9, 13), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + ), + neck=dict( + type='YOLOXPAFPN', + in_channels=[256, 512, 1024], + out_channels=256, + num_csp_blocks=3, + use_depthwise=False, + upsample_cfg=dict(scale_factor=2, mode='nearest'), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish')), + bbox_head=dict( + type='YOLOXHead', + num_classes=80, + in_channels=256, + feat_channels=256, + stacked_convs=2, + strides=(8, 16, 32), + use_depthwise=False, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_bbox=dict( + type='IoULoss', + mode='square', + eps=1e-16, + reduction='sum', + loss_weight=5.0), + loss_obj=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)), + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. + test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) + +# dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), + dict( + type='RandomAffine', + scaling_ratio_range=(0.1, 2), + # img_scale is (width, height) + border=(-img_scale[0] // 2, -img_scale[1] // 2)), + dict( + type='MixUp', + img_scale=img_scale, + ratio_range=(0.8, 1.6), + pad_val=114.0), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + # According to the official implementation, multi-scale + # training is not considered here but in the + # 'mmdet/models/detectors/yolox.py'. + # Resize and Pad are for the last 15 epochs when Mosaic, + # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + # If the image is three-channel, the pad value needs + # to be set separately for each channel. + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type='PackDetInputs') +] + +train_dataset = dict( + # use MultiImageMixDataset wrapper to support mosaic and mixup + type='MultiImageMixDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=[ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True) + ], + filter_cfg=dict(filter_empty_gt=False, min_size=32), + backend_args=backend_args), + pipeline=train_pipeline) + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=train_dataset) +val_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + backend_args=backend_args) +test_evaluator = val_evaluator + +# training settings +max_epochs = 300 +num_last_epochs = 15 +interval = 10 + +train_cfg = dict(max_epochs=max_epochs, val_interval=interval) + +# optimizer +# default 8 gpu +base_lr = 0.01 +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4, + nesterov=True), + paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) + +# learning rate +param_scheduler = [ + dict( + # use quadratic formula to warm up 5 epochs + # and lr is updated by iteration + # TODO: fix default scope in get function + type='mmdet.QuadraticWarmupLR', + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + # use cosine lr from 5 to 285 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=5, + T_max=max_epochs - num_last_epochs, + end=max_epochs - num_last_epochs, + by_epoch=True, + convert_to_iter_based=True), + dict( + # use fixed lr during last 15 epochs + type='ConstantLR', + by_epoch=True, + factor=1, + begin=max_epochs - num_last_epochs, + end=max_epochs, + ) +] + +default_hooks = dict( + checkpoint=dict( + interval=interval, + max_keep_ckpts=3 # only keep latest 3 checkpoints + )) + +custom_hooks = [ + dict( + type='YOLOXModeSwitchHook', + num_last_epochs=num_last_epochs, + priority=48), + dict(type='SyncNormHook', priority=48), + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0001, + update_buffers=True, + priority=49) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/controlnet_aux/hed/__init__.py b/controlnet_aux/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fc8525ab133e4a9979ab517535fc16d8a9e39d --- /dev/null +++ b/controlnet_aux/hed/__init__.py @@ -0,0 +1,129 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, nms, resize_image, safe_step + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + +class HEDdetector: + def __init__(self, netNetwork): + self.netNetwork = netNetwork + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): + filename = filename or "ControlNetHED.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + netNetwork = ControlNetHED_Apache2() + netNetwork.load_state_dict(torch.load(model_path, map_location='cpu')) + netNetwork.float().eval() + + return cls(netNetwork) + + def to(self, device): + self.netNetwork.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.netNetwork.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().to(device) + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/leres/__init__.py b/controlnet_aux/leres/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d7728cce50279dd265ca4b8e88aa9abdd65f07 --- /dev/null +++ b/controlnet_aux/leres/__init__.py @@ -0,0 +1,118 @@ +import os + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from .leres.depthmap import estimateboost, estimateleres +from .leres.multi_depth_model_woauxi import RelDepthModel +from .leres.net_tools import strip_prefix_if_present +from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel +from .pix2pix.options.test_options import TestOptions + + +class LeresDetector: + def __init__(self, model, pix2pixmodel): + self.model = model + self.pix2pixmodel = pix2pixmodel + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None, local_files_only=False): + filename = filename or "res101.pth" + pix2pix_filename = pix2pix_filename or "latest_net_G.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + checkpoint = torch.load(model_path, map_location=torch.device('cpu')) + + model = RelDepthModel(backbone='resnext101') + model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True) + del checkpoint + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, pix2pix_filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir, local_files_only=local_files_only) + + opt = TestOptions().parse() + if not torch.cuda.is_available(): + opt.gpu_ids = [] # cpu mode + pix2pixmodel = Pix2Pix4DepthModel(opt) + pix2pixmodel.save_dir = os.path.dirname(model_path) + pix2pixmodel.load_networks('latest') + pix2pixmodel.eval() + + return cls(model, pix2pixmodel) + + def to(self, device): + self.model.to(device) + # TODO - refactor pix2pix implementation to support device migration + # self.pix2pixmodel.to(device) + return self + + def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"): + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + height, width, dim = input_image.shape + + with torch.no_grad(): + + if boost: + depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height)) + else: + depth = estimateleres(input_image, self.model, width, height) + + numbytes=2 + depth_min = depth.min() + depth_max = depth.max() + max_val = (2**(8*numbytes))-1 + + # check output before normalizing and mapping to 16 bit + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape) + + # single channel, 16 bit image + depth_image = out.astype("uint16") + + # convert to uint8 + depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0)) + + # remove near + if thr_a != 0: + thr_a = ((thr_a/100)*255) + depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1] + + # invert image + depth_image = cv2.bitwise_not(depth_image) + + # remove bg + if thr_b != 0: + thr_b = ((thr_b/100)*255) + depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1] + + detected_map = depth_image + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/leres/leres/LICENSE b/controlnet_aux/leres/leres/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e0f1d07d98d4e85e684734d058dfe2515d215405 --- /dev/null +++ b/controlnet_aux/leres/leres/LICENSE @@ -0,0 +1,23 @@ +https://github.com/thygate/stable-diffusion-webui-depthmap-script + +MIT License + +Copyright (c) 2023 Bob Thiry + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/controlnet_aux/leres/leres/Resnet.py b/controlnet_aux/leres/leres/Resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f12c9975c1aa05401269be3ca3dbaa56bde55581 --- /dev/null +++ b/controlnet_aux/leres/leres/Resnet.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn as NN + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + #self.avgpool = nn.AvgPool2d(7, stride=1) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + features = [] + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + x = self.layer2(x) + features.append(x) + x = self.layer3(x) + features.append(x) + x = self.layer4(x) + features.append(x) + + return features + + +def resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def resnet34(pretrained=True, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + return model + + +def resnet101(pretrained=True, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + + return model + + +def resnet152(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + return model diff --git a/controlnet_aux/leres/leres/Resnext_torch.py b/controlnet_aux/leres/leres/Resnext_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9af54fcc3e5b363935ef60c8aaf269110c0d6611 --- /dev/null +++ b/controlnet_aux/leres/leres/Resnext_torch.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# coding: utf-8 +import torch.nn as nn + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + +__all__ = ['resnext101_32x8d'] + + +model_urls = { + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + features = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + + x = self.layer2(x) + features.append(x) + + x = self.layer3(x) + features.append(x) + + x = self.layer4(x) + features.append(x) + + #x = self.avgpool(x) + #x = torch.flatten(x, 1) + #x = self.fc(x) + + return features + + def forward(self, x): + return self._forward_impl(x) + + + +def resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + return model + diff --git a/controlnet_aux/leres/leres/__init__.py b/controlnet_aux/leres/leres/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/leres/leres/depthmap.py b/controlnet_aux/leres/leres/depthmap.py new file mode 100644 index 0000000000000000000000000000000000000000..fc743bf4946b514a53f8d286a395e33c7b612582 --- /dev/null +++ b/controlnet_aux/leres/leres/depthmap.py @@ -0,0 +1,548 @@ +# Author: thygate +# https://github.com/thygate/stable-diffusion-webui-depthmap-script + +import gc +from operator import getitem + +import cv2 +import numpy as np +import skimage.measure +import torch +from torchvision.transforms import transforms + +from ...util import torch_gc + +whole_size_threshold = 1600 # R_max from the paper +pix2pixsize = 1024 + +def scale_torch(img): + """ + Scale the image and output it in torch.tensor. + :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W] + :param scale: the scale factor. float + :return: img. [C, H, W] + """ + if len(img.shape) == 2: + img = img[np.newaxis, :, :] + if img.shape[2] == 3: + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )]) + img = transform(img.astype(np.float32)) + else: + img = img.astype(np.float32) + img = torch.from_numpy(img) + return img + +def estimateleres(img, model, w, h): + device = next(iter(model.parameters())).device + # leres transform input + rgb_c = img[:, :, ::-1].copy() + A_resize = cv2.resize(rgb_c, (w, h)) + img_torch = scale_torch(A_resize)[None, :, :, :] + + # compute + with torch.no_grad(): + img_torch = img_torch.to(device) + prediction = model.depth_model(img_torch) + + prediction = prediction.squeeze().cpu().numpy() + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + return prediction + +def generatemask(size): + # Generates a Guassian mask + mask = np.zeros(size, dtype=np.float32) + sigma = int(size[0]/16) + k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1) + mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1 + mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma) + mask = (mask - mask.min()) / (mask.max() - mask.min()) + mask = mask.astype(np.float32) + return mask + +def resizewithpool(img, size): + i_size = img.shape[0] + n = int(np.floor(i_size/size)) + + out = skimage.measure.block_reduce(img, (n, n), np.max) + return out + +def rgb2gray(rgb): + # Converts rgb to gray + return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140]) + +def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000): + # Returns the R_x resolution described in section 5 of the main paper. + + # Parameters: + # img :input rgb image + # basesize : size the dilation kernel which is equal to receptive field of the network. + # confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue. + # scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3. + # whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper) + + # Returns: + # outputsize_scale*speed_scale :The computed R_x resolution + # patch_scale: K parameter from section 6 of the paper + + # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search + speed_scale = 32 + image_dim = int(min(img.shape[0:2])) + + gray = rgb2gray(img) + grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)) + grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA) + + # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues + m = grad.min() + M = grad.max() + middle = m + (0.4 * (M - m)) + grad[grad < middle] = 0 + grad[grad >= middle] = 1 + + # dilation kernel with size of the receptive field + kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float) + # dilation kernel with size of the a quarter of receptive field used to compute k + # as described in section 6 of main paper + kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float) + + # Output resolution limit set by the whole_size_threshold and scale_threshold. + threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2])) + + outputsize_scale = basesize / speed_scale + for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))): + grad_resized = resizewithpool(grad, p_size) + grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST) + grad_resized[grad_resized >= 0.5] = 1 + grad_resized[grad_resized < 0.5] = 0 + + dilated = cv2.dilate(grad_resized, kernel, iterations=1) + meanvalue = (1-dilated).mean() + if meanvalue > confidence: + break + else: + outputsize_scale = p_size + + grad_region = cv2.dilate(grad_resized, kernel2, iterations=1) + patch_scale = grad_region.mean() + + return int(outputsize_scale*speed_scale), patch_scale + +# Generate a double-input depth estimation +def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel): + # Generate the low resolution estimation + estimate1 = singleestimate(img, size1, model, net_type) + # Resize to the inference size of merge network. + estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Generate the high resolution estimation + estimate2 = singleestimate(img, size2, model, net_type) + # Resize to the inference size of merge network. + estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Inference on the merge model + pix2pixmodel.set_input(estimate1, estimate2) + pix2pixmodel.test() + visuals = pix2pixmodel.get_current_visuals() + prediction_mapped = visuals['fake_B'] + prediction_mapped = (prediction_mapped+1)/2 + prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / ( + torch.max(prediction_mapped) - torch.min(prediction_mapped)) + prediction_mapped = prediction_mapped.squeeze().cpu().numpy() + + return prediction_mapped + +# Generate a single-input depth estimation +def singleestimate(img, msize, model, net_type): + # if net_type == 0: + return estimateleres(img, model, msize, msize) + # else: + # return estimatemidasBoost(img, model, msize, msize) + +def applyGridpatch(blsize, stride, img, box): + # Extract a simple grid patch. + counter1 = 0 + patch_bound_list = {} + for k in range(blsize, img.shape[1] - blsize, stride): + for j in range(blsize, img.shape[0] - blsize, stride): + patch_bound_list[str(counter1)] = {} + patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize] + patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1], + patchbounds[2] - patchbounds[0]] + patch_bound_list[str(counter1)]['rect'] = patch_bound + patch_bound_list[str(counter1)]['size'] = patch_bound[2] + counter1 = counter1 + 1 + return patch_bound_list + +# Generating local patches to perform the local refinement described in section 6 of the main paper. +def generatepatchs(img, base_size): + + # Compute the gradients as a proxy of the contextual cues. + img_gray = rgb2gray(img) + whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\ + np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)) + + threshold = whole_grad[whole_grad > 0].mean() + whole_grad[whole_grad < threshold] = 0 + + # We use the integral image to speed-up the evaluation of the amount of gradients for each patch. + gf = whole_grad.sum()/len(whole_grad.reshape(-1)) + grad_integral_image = cv2.integral(whole_grad) + + # Variables are selected such that the initial patch size would be the receptive field size + # and the stride is set to 1/3 of the receptive field size. + blsize = int(round(base_size/2)) + stride = int(round(blsize*0.75)) + + # Get initial Grid + patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0]) + + # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine + # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map. + print("Selecting patches ...") + patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf) + + # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest + # patch + patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True) + return patchset + +def getGF_fromintegral(integralimage, rect): + # Computes the gradient density of a given patch from the gradient integral image. + x1 = rect[1] + x2 = rect[1]+rect[3] + y1 = rect[0] + y2 = rect[0]+rect[2] + value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1] + return value + +# Adaptively select patches +def adaptiveselection(integral_grad, patch_bound_list, gf): + patchlist = {} + count = 0 + height, width = integral_grad.shape + + search_step = int(32/factor) + + # Go through all patches + for c in range(len(patch_bound_list)): + # Get patch + bbox = patch_bound_list[str(c)]['rect'] + + # Compute the amount of gradients present in the patch from the integral image. + cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3]) + + # Check if patching is beneficial by comparing the gradient density of the patch to + # the gradient density of the whole image + if cgf >= gf: + bbox_test = bbox.copy() + patchlist[str(count)] = {} + + # Enlarge each patch until the gradient density of the patch is equal + # to the whole image gradient density + while True: + + bbox_test[0] = bbox_test[0] - int(search_step/2) + bbox_test[1] = bbox_test[1] - int(search_step/2) + + bbox_test[2] = bbox_test[2] + search_step + bbox_test[3] = bbox_test[3] + search_step + + # Check if we are still within the image + if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \ + or bbox_test[0] + bbox_test[2] >= width: + break + + # Compare gradient density + cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3]) + if cgf < gf: + break + bbox = bbox_test.copy() + + # Add patch to selected patches + patchlist[str(count)]['rect'] = bbox + patchlist[str(count)]['size'] = bbox[2] + count = count + 1 + + # Return selected patches + return patchlist + +def impatch(image, rect): + # Extract the given patch pixels from a given image. + w1 = rect[0] + h1 = rect[1] + w2 = w1 + rect[2] + h2 = h1 + rect[3] + image_patch = image[h1:h2, w1:w2] + return image_patch + +class ImageandPatchs: + def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1): + self.root_dir = root_dir + self.patchsinfo = patchsinfo + self.name = name + self.patchs = patchsinfo + self.scale = scale + + self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)), + interpolation=cv2.INTER_CUBIC) + + self.do_have_estimate = False + self.estimation_updated_image = None + self.estimation_base_image = None + + def __len__(self): + return len(self.patchs) + + def set_base_estimate(self, est): + self.estimation_base_image = est + if self.estimation_updated_image is not None: + self.do_have_estimate = True + + def set_updated_estimate(self, est): + self.estimation_updated_image = est + if self.estimation_base_image is not None: + self.do_have_estimate = True + + def __getitem__(self, index): + patch_id = int(self.patchs[index][0]) + rect = np.array(self.patchs[index][1]['rect']) + msize = self.patchs[index][1]['size'] + + ## applying scale to rect: + rect = np.round(rect * self.scale) + rect = rect.astype('int') + msize = round(msize * self.scale) + + patch_rgb = impatch(self.rgb_image, rect) + if self.do_have_estimate: + patch_whole_estimate_base = impatch(self.estimation_base_image, rect) + patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect) + return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base, + 'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect, + 'size': msize, 'id': patch_id} + else: + return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id} + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + """ + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + """ + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + #self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + #if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt + + +def estimateboost(img, model, model_type, pix2pixmodel, max_res=512, depthmap_script_boost_rmax=None): + global whole_size_threshold + + # get settings + if depthmap_script_boost_rmax: + whole_size_threshold = depthmap_script_boost_rmax + + if model_type == 0: #leres + net_receptive_field_size = 448 + patch_netsize = 2 * net_receptive_field_size + elif model_type == 1: #dpt_beit_large_512 + net_receptive_field_size = 512 + patch_netsize = 2 * net_receptive_field_size + else: #other midas + net_receptive_field_size = 384 + patch_netsize = 2 * net_receptive_field_size + + gc.collect() + torch_gc() + + # Generate mask used to smoothly blend the local pathc estimations to the base estimate. + # It is arbitrarily large to avoid artifacts during rescaling for each crop. + mask_org = generatemask((3000, 3000)) + mask = mask_org.copy() + + # Value x of R_x defined in the section 5 of the main paper. + r_threshold_value = 0.2 + #if R0: + # r_threshold_value = 0 + + input_resolution = img.shape + scale_threshold = 3 # Allows up-scaling with a scale up to 3 + + # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the + # supplementary material. + whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold) + + # print('wholeImage being processed in :', whole_image_optimal_size) + + # Generate the base estimate using the double estimation. + whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel) + + # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select + # small high-density regions of the image. + global factor + factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2) + # print('Adjust factor is:', 1/factor) + + # Check if Local boosting is beneficial. + if max_res < whole_image_optimal_size: + # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result") + return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) + + # Compute the default target resolution. + if img.shape[0] > img.shape[1]: + a = 2 * whole_image_optimal_size + b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0]) + else: + a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1]) + b = 2 * whole_image_optimal_size + b = int(round(b / factor)) + a = int(round(a / factor)) + + """ + # recompute a, b and saturate to max res. + if max(a,b) > max_res: + print('Default Res is higher than max-res: Reducing final resolution') + if img.shape[0] > img.shape[1]: + a = max_res + b = round(max_res * img.shape[1] / img.shape[0]) + else: + a = round(max_res * img.shape[0] / img.shape[1]) + b = max_res + b = int(b) + a = int(a) + """ + + img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC) + + # Extract selected patches for local refinement + base_size = net_receptive_field_size * 2 + patchset = generatepatchs(img, base_size) + + # print('Target resolution: ', img.shape) + + # Computing a scale in case user prompted to generate the results as the same resolution of the input. + # Notice that our method output resolution is independent of the input resolution and this parameter will only + # enable a scaling operation during the local patch merge implementation to generate results with the same resolution + # as the input. + """ + if output_resolution == 1: + mergein_scale = input_resolution[0] / img.shape[0] + print('Dynamicly change merged-in resolution; scale:', mergein_scale) + else: + mergein_scale = 1 + """ + # always rescale to input res for now + mergein_scale = input_resolution[0] / img.shape[0] + + imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale) + whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale), + round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC) + imageandpatchs.set_base_estimate(whole_estimate_resized.copy()) + imageandpatchs.set_updated_estimate(whole_estimate_resized.copy()) + + print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2]) + print('Patches to process: '+str(len(imageandpatchs))) + + # Enumerate through all patches, generate their estimations and refining the base estimate. + for patch_ind in range(len(imageandpatchs)): + + # Get patch information + patch = imageandpatchs[patch_ind] # patch object + patch_rgb = patch['patch_rgb'] # rgb patch + patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base + rect = patch['rect'] # patch size and location + patch_id = patch['id'] # patch ID + org_size = patch_whole_estimate_base.shape # the original size from the unscaled input + print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect) + + # We apply double estimation for patches. The high resolution value is fixed to twice the receptive + # field size of the network for patches to accelerate the process. + patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel) + patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Merging the patch estimation into the base estimate using our merge network: + # We feed the patch estimation and the same region from the updated base estimate to the merge network + # to generate the target estimate for the corresponding region. + pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation) + + # Run merging network + pix2pixmodel.test() + visuals = pix2pixmodel.get_current_visuals() + + prediction_mapped = visuals['fake_B'] + prediction_mapped = (prediction_mapped+1)/2 + prediction_mapped = prediction_mapped.squeeze().cpu().numpy() + + mapped = prediction_mapped + + # We use a simple linear polynomial to make sure the result of the merge network would match the values of + # base estimate + p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1) + merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape) + + merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC) + + # Get patch size and location + w1 = rect[0] + h1 = rect[1] + w2 = w1 + rect[2] + h2 = h1 + rect[3] + + # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size + # and resize it to our needed size while merging the patches. + if mask.shape != org_size: + mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR) + + tobemergedto = imageandpatchs.estimation_updated_image + + # Update the whole estimation: + # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless + # blending at the boundaries of the patch region. + tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask) + imageandpatchs.set_updated_estimate(tobemergedto) + + # output + return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) diff --git a/controlnet_aux/leres/leres/multi_depth_model_woauxi.py b/controlnet_aux/leres/leres/multi_depth_model_woauxi.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf35d7843e00be5d3c831d72b9ab5d64d130f93 --- /dev/null +++ b/controlnet_aux/leres/leres/multi_depth_model_woauxi.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from . import network_auxi as network +from .net_tools import get_func + + +class RelDepthModel(nn.Module): + def __init__(self, backbone='resnet50'): + super(RelDepthModel, self).__init__() + if backbone == 'resnet50': + encoder = 'resnet50_stride32' + elif backbone == 'resnext101': + encoder = 'resnext101_stride32x8d' + self.depth_model = DepthModel(encoder) + + def inference(self, rgb): + with torch.no_grad(): + input = rgb.to(self.depth_model.device) + depth = self.depth_model(input) + #pred_depth_out = depth - depth.min() + 0.01 + return depth #pred_depth_out + + +class DepthModel(nn.Module): + def __init__(self, encoder): + super(DepthModel, self).__init__() + backbone = network.__name__.split('.')[-1] + '.' + encoder + self.encoder_modules = get_func(backbone)() + self.decoder_modules = network.Decoder() + + def forward(self, x): + lateral_out = self.encoder_modules(x) + out_logit = self.decoder_modules(lateral_out) + return out_logit \ No newline at end of file diff --git a/controlnet_aux/leres/leres/net_tools.py b/controlnet_aux/leres/leres/net_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..2f213315046e078bb861d65d3ef4a6fc446e945d --- /dev/null +++ b/controlnet_aux/leres/leres/net_tools.py @@ -0,0 +1,54 @@ +import importlib +import torch +import os +from collections import OrderedDict + + +def get_func(func_name): + """Helper to return a function object by name. func_name must identify a + function in this module or the path to a function relative to the base + 'modeling' module. + """ + if func_name == '': + return None + try: + parts = func_name.split('.') + # Refers to a function in this module + if len(parts) == 1: + return globals()[parts[0]] + # Otherwise, assume we're referencing a module under modeling + module_name = 'controlnet_aux.leres.leres.' + '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + return getattr(module, parts[-1]) + except Exception: + print('Failed to f1ind function: %s', func_name) + raise + +def load_ckpt(args, depth_model, shift_model, focal_model): + """ + Load checkpoint. + """ + if os.path.isfile(args.load_ckpt): + print("loading checkpoint %s" % args.load_ckpt) + checkpoint = torch.load(args.load_ckpt) + if shift_model is not None: + shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'), + strict=True) + if focal_model is not None: + focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'), + strict=True) + depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), + strict=True) + del checkpoint + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "")] = value + return stripped_state_dict \ No newline at end of file diff --git a/controlnet_aux/leres/leres/network_auxi.py b/controlnet_aux/leres/leres/network_auxi.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd87011a5339aca632d1a10b217c8737bdc794f --- /dev/null +++ b/controlnet_aux/leres/leres/network_auxi.py @@ -0,0 +1,417 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + +from . import Resnet, Resnext_torch + + +def resnet50_stride32(): + return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2]) + +def resnext101_stride32x8d(): + return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2]) + + +class Decoder(nn.Module): + def __init__(self): + super(Decoder, self).__init__() + self.inchannels = [256, 512, 1024, 2048] + self.midchannels = [256, 256, 256, 512] + self.upfactors = [2,2,2,2] + self.outchannels = 1 + + self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3]) + self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True) + self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True) + + self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2]) + self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1]) + self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0]) + + self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2) + self._init_params() + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, features): + x_32x = self.conv(features[3]) # 1/32 + x_32 = self.conv1(x_32x) + x_16 = self.upsample(x_32) # 1/16 + + x_8 = self.ffm2(features[2], x_16) # 1/8 + x_4 = self.ffm1(features[1], x_8) # 1/4 + x_2 = self.ffm0(features[0], x_4) # 1/2 + #----------------------------------------- + x = self.outconv(x_2) # original size + return x + +class DepthNet(nn.Module): + __factory = { + 18: Resnet.resnet18, + 34: Resnet.resnet34, + 50: Resnet.resnet50, + 101: Resnet.resnet101, + 152: Resnet.resnet152 + } + def __init__(self, + backbone='resnet', + depth=50, + upfactors=[2, 2, 2, 2]): + super(DepthNet, self).__init__() + self.backbone = backbone + self.depth = depth + self.pretrained = False + self.inchannels = [256, 512, 1024, 2048] + self.midchannels = [256, 256, 256, 512] + self.upfactors = upfactors + self.outchannels = 1 + + # Build model + if self.backbone == 'resnet': + if self.depth not in DepthNet.__factory: + raise KeyError("Unsupported depth:", self.depth) + self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained) + elif self.backbone == 'resnext101_32x8d': + self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained) + else: + self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained) + + def forward(self, x): + x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4 + return x + + +class FTB(nn.Module): + def __init__(self, inchannels, midchannels=512): + super(FTB, self).__init__() + self.in1 = inchannels + self.mid = midchannels + self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, + bias=True) + # NN.BatchNorm2d + self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, + padding=1, stride=1, bias=True), \ + nn.BatchNorm2d(num_features=self.mid), \ + nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, + padding=1, stride=1, bias=True)) + self.relu = nn.ReLU(inplace=True) + + self.init_params() + + def forward(self, x): + x = self.conv1(x) + x = x + self.conv_branch(x) + x = self.relu(x) + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class ATA(nn.Module): + def __init__(self, inchannels, reduction=8): + super(ATA, self).__init__() + self.inchannels = inchannels + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction), + nn.ReLU(inplace=True), + nn.Linear(self.inchannels // reduction, self.inchannels), + nn.Sigmoid()) + self.init_params() + + def forward(self, low_x, high_x): + n, c, _, _ = low_x.size() + x = torch.cat([low_x, high_x], 1) + x = self.avg_pool(x) + x = x.view(n, -1) + x = self.fc(x).view(n, c, 1, 1) + x = low_x * x + high_x + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + # init.normal(m.weight, std=0.01) + init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + # init.normal_(m.weight, std=0.01) + init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class FFM(nn.Module): + def __init__(self, inchannels, midchannels, outchannels, upfactor=2): + super(FFM, self).__init__() + self.inchannels = inchannels + self.midchannels = midchannels + self.outchannels = outchannels + self.upfactor = upfactor + + self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels) + # self.ata = ATA(inchannels = self.midchannels) + self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) + + self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) + + self.init_params() + + def forward(self, low_x, high_x): + x = self.ftb1(low_x) + x = x + high_x + x = self.ftb2(x) + x = self.upsample(x) + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class AO(nn.Module): + # Adaptive output module + def __init__(self, inchannels, outchannels, upfactor=2): + super(AO, self).__init__() + self.inchannels = inchannels + self.outchannels = outchannels + self.upfactor = upfactor + + self.adapt_conv = nn.Sequential( + nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1, + stride=1, bias=True), \ + nn.BatchNorm2d(num_features=self.inchannels // 2), \ + nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1, + stride=1, bias=True), \ + nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)) + + self.init_params() + + def forward(self, x): + x = self.adapt_conv(x) + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + + +# ============================================================================================================== + + +class ResidualConv(nn.Module): + def __init__(self, inchannels): + super(ResidualConv, self).__init__() + # NN.BatchNorm2d + self.conv = nn.Sequential( + # nn.BatchNorm2d(num_features=inchannels), + nn.ReLU(inplace=False), + # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True), + # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True) + nn.Conv2d(in_channels=inchannels, out_channels=inchannels / 2, kernel_size=3, padding=1, stride=1, + bias=False), + nn.BatchNorm2d(num_features=inchannels / 2), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=inchannels / 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, + bias=False) + ) + self.init_params() + + def forward(self, x): + x = self.conv(x) + x + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class FeatureFusion(nn.Module): + def __init__(self, inchannels, outchannels): + super(FeatureFusion, self).__init__() + self.conv = ResidualConv(inchannels=inchannels) + # NN.BatchNorm2d + self.up = nn.Sequential(ResidualConv(inchannels=inchannels), + nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3, + stride=2, padding=1, output_padding=1), + nn.BatchNorm2d(num_features=outchannels), + nn.ReLU(inplace=True)) + + def forward(self, lowfeat, highfeat): + return self.up(highfeat + self.conv(lowfeat)) + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class SenceUnderstand(nn.Module): + def __init__(self, channels): + super(SenceUnderstand, self).__init__() + self.channels = channels + self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), + nn.ReLU(inplace=True)) + self.pool = nn.AdaptiveAvgPool2d(8) + self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels), + nn.ReLU(inplace=True)) + self.conv2 = nn.Sequential( + nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0), + nn.ReLU(inplace=True)) + self.initial_params() + + def forward(self, x): + n, c, h, w = x.size() + x = self.conv1(x) + x = self.pool(x) + x = x.view(n, -1) + x = self.fc(x) + x = x.view(n, self.channels, 1, 1) + x = self.conv2(x) + x = x.repeat(1, 1, h, w) + return x + + def initial_params(self, dev=0.01): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # print torch.sum(m.weight) + m.weight.data.normal_(0, dev) + if m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.ConvTranspose2d): + # print torch.sum(m.weight) + m.weight.data.normal_(0, dev) + if m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, dev) + + +if __name__ == '__main__': + net = DepthNet(depth=50, pretrained=True) + print(net) + inputs = torch.ones(4,3,128,128) + out = net(inputs) + print(out.size()) + diff --git a/controlnet_aux/leres/pix2pix/LICENSE b/controlnet_aux/leres/pix2pix/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..38b1a24fd389a138b930dcf1ee606ef97a0186c8 --- /dev/null +++ b/controlnet_aux/leres/pix2pix/LICENSE @@ -0,0 +1,19 @@ +https://github.com/compphoto/BoostingMonocularDepth + +Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved. + +This software is for academic use only. A redistribution of this +software, with or without modifications, has to be for academic +use only, while giving the appropriate credit to the original +authors of the software. The methods implemented as a part of +this software may be covered under patents or patent applications. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/controlnet_aux/leres/pix2pix/__init__.py b/controlnet_aux/leres/pix2pix/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/leres/pix2pix/models/__init__.py b/controlnet_aux/leres/pix2pix/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..301c966fca7a375c359b7ee7d455e23ee82ebb64 --- /dev/null +++ b/controlnet_aux/leres/pix2pix/models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" + +import importlib +from .base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "controlnet_aux.leres.pix2pix.models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/controlnet_aux/leres/pix2pix/models/base_model.py b/controlnet_aux/leres/pix2pix/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..66ec298f77cf769e39da38d1107e0b6dc38d519d --- /dev/null +++ b/controlnet_aux/leres/pix2pix/models/base_model.py @@ -0,0 +1,244 @@ +import gc +import os +from abc import ABC, abstractmethod +from collections import OrderedDict + +import torch + +from ....util import torch_gc +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this function, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch + self.load_networks(load_suffix) + self.print_networks(opt.verbose) + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + old_lr = self.optimizers[0].param_groups[0]['lr'] + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate %.7f -> %.7f' % (old_lr, lr)) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + + def unload_network(self, name): + """Unload network and gc. + """ + if isinstance(name, str): + net = getattr(self, 'net' + name) + del net + gc.collect() + torch_gc() + return None + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + # print('Loading depth boost model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/controlnet_aux/leres/pix2pix/models/base_model_hg.py b/controlnet_aux/leres/pix2pix/models/base_model_hg.py new file mode 100644 index 0000000000000000000000000000000000000000..1709accdf0b048b3793dfd1f58d1b06c35f7b907 --- /dev/null +++ b/controlnet_aux/leres/pix2pix/models/base_model_hg.py @@ -0,0 +1,58 @@ +import os +import torch + +class BaseModelHG(): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '_%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda(device_id=gpu_ids[0]) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print(save_path) + model = torch.load(save_path) + return model + # network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass diff --git a/controlnet_aux/leres/pix2pix/models/networks.py b/controlnet_aux/leres/pix2pix/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf912b2973721a02deefd042af621e732bad59f --- /dev/null +++ b/controlnet_aux/leres/pix2pix/models/networks.py @@ -0,0 +1,623 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler + + +############################################################################### +# Helper Functions +############################################################################### + + +class Identity(nn.Module): + def forward(self, x): + return x + + +def get_norm_layer(norm_type='instance'): + """Return a normalization layer + + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. + """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + # print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Create a generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + norm (str) -- the name of normalization layers used in the network: batch | instance | none + use_dropout (bool) -- if use dropout layers. + init_type (str) -- the name of our initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a generator + + Our current implementation provides two types of generators: + U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) + The original U-Net paper: https://arxiv.org/abs/1505.04597 + + Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) + Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). + + + The generator has been initialized by . It uses RELU for non-linearity. + """ + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netG == 'resnet_9blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + elif netG == 'resnet_6blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + elif netG == 'resnet_12blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12) + elif netG == 'unet_128': + net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_256': + net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_672': + net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_960': + net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_1024': + net = UnetGenerator(input_nc, output_nc, 10, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % netG) + return init_net(net, init_type, init_gain, gpu_ids) + + +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): + """Create a discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the first conv layer + netD (str) -- the architecture's name: basic | n_layers | pixel + n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' + norm (str) -- the type of normalization layers used in the network. + init_type (str) -- the name of the initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a discriminator + + Our current implementation provides three types of discriminators: + [basic]: 'PatchGAN' classifier described in the original pix2pix paper. + It can classify whether 70×70 overlapping patches are real or fake. + Such a patch-level discriminator architecture has fewer parameters + than a full-image discriminator and can work on arbitrarily-sized images + in a fully convolutional fashion. + + [n_layers]: With this mode, you can specify the number of conv layers in the discriminator + with the parameter (default=3 as used in [basic] (PatchGAN).) + + [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. + It encourages greater color diversity but has no effect on spatial statistics. + + The discriminator has been initialized by . It uses Leakly RELU for non-linearity. + """ + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netD == 'basic': # default PatchGAN classifier + net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) + elif netD == 'n_layers': # more options + net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) + elif netD == 'pixel': # classify if each pixel is real or fake + net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) + return init_net(net, init_type, init_gain, gpu_ids) + + +############################################################################## +# Classes +############################################################################## +class GANLoss(nn.Module): + """Define different GAN objectives. + + The GANLoss class abstracts away the need to create the target label tensor + that has the same size as the input. + """ + + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): + """ Initialize the GANLoss class. + + Parameters: + gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. + target_real_label (bool) - - label for a real image + target_fake_label (bool) - - label of a fake image + + Note: Do not use sigmoid as the last layer of Discriminator. + LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. + """ + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.gan_mode = gan_mode + if gan_mode == 'lsgan': + self.loss = nn.MSELoss() + elif gan_mode == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif gan_mode in ['wgangp']: + self.loss = None + else: + raise NotImplementedError('gan mode %s not implemented' % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + """Create label tensors with the same size as the input. + + Parameters: + prediction (tensor) - - tpyically the prediction from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + + Returns: + A label tensor filled with ground truth label, and with the size of the input + """ + + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def __call__(self, prediction, target_is_real): + """Calculate loss given Discriminator's output and grount truth labels. + + Parameters: + prediction (tensor) - - tpyically the prediction output from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + + Returns: + the calculated loss. + """ + if self.gan_mode in ['lsgan', 'vanilla']: + target_tensor = self.get_target_tensor(prediction, target_is_real) + loss = self.loss(prediction, target_tensor) + elif self.gan_mode == 'wgangp': + if target_is_real: + loss = -prediction.mean() + else: + loss = prediction.mean() + return loss + + +def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): + """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 + + Arguments: + netD (network) -- discriminator network + real_data (tensor array) -- real images + fake_data (tensor array) -- generated images from the generator + device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + type (str) -- if we mix real and fake data or not [real | fake | mixed]. + constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 + lambda_gp (float) -- weight for this loss + + Returns the gradient penalty loss + """ + if lambda_gp > 0.0: + if type == 'real': # either use real images, fake images, or a linear interpolation of two. + interpolatesv = real_data + elif type == 'fake': + interpolatesv = fake_data + elif type == 'mixed': + alpha = torch.rand(real_data.shape[0], 1, device=device) + alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) + interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) + else: + raise NotImplementedError('{} not implemented'.format(type)) + interpolatesv.requires_grad_(True) + disc_interpolates = netD(interpolatesv) + gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True) + gradients = gradients[0].view(real_data.size(0), -1) # flat the data + gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps + return gradient_penalty, gradients + else: + return 0.0, None + + +class ResnetGenerator(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. + + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + +class PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) diff --git a/controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py b/controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..89e89652feb96314973a050c5a2477b474630abb --- /dev/null +++ b/controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py @@ -0,0 +1,155 @@ +import torch +from .base_model import BaseModel +from . import networks + + +class Pix2Pix4DepthModel(BaseModel): + """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. + + The model training requires '--dataset_mode aligned' dataset. + By default, it uses a '--netG unet256' U-Net generator, + a '--netD basic' discriminator (PatchGAN), + and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). + + pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf + """ + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + + For pix2pix, we do not use image buffer + The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 + By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. + """ + # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) + parser.set_defaults(input_nc=2,output_nc=1,norm='none', netG='unet_1024', dataset_mode='depthmerge') + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla',) + parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss') + return parser + + def __init__(self, opt): + """Initialize the pix2pix class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call + + self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] + # self.loss_names = ['G_L1'] + + # specify the images you want to save/display. The training/test scripts will call + if self.isTrain: + self.visual_names = ['outer','inner', 'fake_B', 'real_B'] + else: + self.visual_names = ['fake_B'] + + # specify the models you want to save to the disk. The training/test scripts will call and + if self.isTrain: + self.model_names = ['G','D'] + else: # during test time, only load G + self.model_names = ['G'] + + # define networks (both generator and discriminator) + self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none', + False, 'normal', 0.02, self.gpu_ids) + + if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + # define loss functions + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + # initialize optimizers; schedulers will be automatically created by function . + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + + def set_input_train(self, input): + self.outer = input['data_outer'].to(self.device) + self.outer = torch.nn.functional.interpolate(self.outer,(1024,1024),mode='bilinear',align_corners=False) + + self.inner = input['data_inner'].to(self.device) + self.inner = torch.nn.functional.interpolate(self.inner,(1024,1024),mode='bilinear',align_corners=False) + + self.image_paths = input['image_path'] + + if self.isTrain: + self.gtfake = input['data_gtfake'].to(self.device) + self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False) + self.real_B = self.gtfake + + self.real_A = torch.cat((self.outer, self.inner), 1) + + def set_input(self, outer, inner): + inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0) + outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0) + + inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner)) + outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer)) + + inner = self.normalize(inner) + outer = self.normalize(outer) + + self.real_A = torch.cat((outer, inner), 1).to(self.device) + + + def normalize(self, input): + input = input * 2 + input = input - 1 + return input + + def forward(self): + """Run forward pass; called by both functions and .""" + self.fake_B = self.netG(self.real_A) # G(A) + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_B + fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + self.loss_D.backward() + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD(fake_AB) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + # combine loss and calculate gradients + self.loss_G = self.loss_G_L1 + self.loss_G_GAN + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + # update G + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_G.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer_G.step() # udpate G's weights \ No newline at end of file diff --git a/controlnet_aux/leres/pix2pix/options/__init__.py b/controlnet_aux/leres/pix2pix/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90 --- /dev/null +++ b/controlnet_aux/leres/pix2pix/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/controlnet_aux/leres/pix2pix/options/base_options.py b/controlnet_aux/leres/pix2pix/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..533a1e88a7e8494223f6994e6861c93667754f83 --- /dev/null +++ b/controlnet_aux/leres/pix2pix/options/base_options.py @@ -0,0 +1,156 @@ +import argparse +import os +from ...pix2pix.util import util +# import torch +from ...pix2pix import models +# import pix2pix.data +import numpy as np + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here') + # model parameters + parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') + parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + # dataset parameters + parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--load_size', type=int, default=672, help='scale images to this size') + parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size') + parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + parser.add_argument('--data_dir', type=str, required=False, + help='input files directory images can be .png .jpg .tiff') + parser.add_argument('--output_dir', type=str, required=False, + help='result dir. result depth will be png. vides are JMPG as avi') + parser.add_argument('--savecrops', type=int, required=False) + parser.add_argument('--savewholeest', type=int, required=False) + parser.add_argument('--output_resolution', type=int, required=False, + help='0 for no restriction 1 for resize to input size') + parser.add_argument('--net_receptive_field_size', type=int, required=False) + parser.add_argument('--pix2pixsize', type=int, required=False) + parser.add_argument('--generatevideo', type=int, required=False) + parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1:strurturedRL') + parser.add_argument('--R0', action='store_true') + parser.add_argument('--R20', action='store_true') + parser.add_argument('--Final', action='store_true') + parser.add_argument('--colorize_results', action='store_true') + parser.add_argument('--max_res', type=float, default=np.inf) + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # modify dataset-related parser options + # dataset_name = opt.dataset_mode + # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name) + # parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + #return parser.parse_args() #EVIL + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + #self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + #if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt diff --git a/controlnet_aux/leres/pix2pix/options/test_options.py b/controlnet_aux/leres/pix2pix/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..a3424b5e3b66d6813f74c8cecad691d7488d121c --- /dev/null +++ b/controlnet_aux/leres/pix2pix/options/test_options.py @@ -0,0 +1,22 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + # Dropout and Batchnorm has different behavioir during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') + # rewrite devalue values + parser.set_defaults(model='pix2pix4depth') + # To avoid cropping, the load_size should be the same as crop_size + parser.set_defaults(load_size=parser.get_default('crop_size')) + self.isTrain = False + return parser diff --git a/controlnet_aux/leres/pix2pix/util/__init__.py b/controlnet_aux/leres/pix2pix/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae36f63d8859ec0c60dcbfe67c4ac324e751ddf7 --- /dev/null +++ b/controlnet_aux/leres/pix2pix/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/controlnet_aux/leres/pix2pix/util/util.py b/controlnet_aux/leres/pix2pix/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8a7aceaa00681cb76675df7866bf8db58c8d2caf --- /dev/null +++ b/controlnet_aux/leres/pix2pix/util/util.py @@ -0,0 +1,105 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, imtype=np.uint16): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = torch.squeeze(image_tensor).cpu().numpy() # convert it into a numpy array + image_numpy = (image_numpy + 1) / 2.0 * (2**16-1) # + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + image_pil = Image.fromarray(image_numpy) + + image_pil = image_pil.convert('I;16') + + # image_pil = Image.fromarray(image_numpy) + # h, w, _ = image_numpy.shape + # + # if aspect_ratio > 1.0: + # image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + # if aspect_ratio < 1.0: + # image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) diff --git a/controlnet_aux/lineart/LICENSE b/controlnet_aux/lineart/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520 --- /dev/null +++ b/controlnet_aux/lineart/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/controlnet_aux/lineart/__init__.py b/controlnet_aux/lineart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef20f32a3dcd51a8b305247a96288d455560461a --- /dev/null +++ b/controlnet_aux/lineart/__init__.py @@ -0,0 +1,167 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image + +norm_layer = nn.InstanceNorm2d + + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + + conv_block = [ nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + return x + self.conv_block(x) + + +class Generator(nn.Module): + def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True): + super(Generator, self).__init__() + + # Initial convolution block + model0 = [ nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, 64, 7), + norm_layer(64), + nn.ReLU(inplace=True) ] + self.model0 = nn.Sequential(*model0) + + # Downsampling + model1 = [] + in_features = 64 + out_features = in_features*2 + for _ in range(2): + model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features*2 + self.model1 = nn.Sequential(*model1) + + model2 = [] + # Residual blocks + for _ in range(n_residual_blocks): + model2 += [ResidualBlock(in_features)] + self.model2 = nn.Sequential(*model2) + + # Upsampling + model3 = [] + out_features = in_features//2 + for _ in range(2): + model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features//2 + self.model3 = nn.Sequential(*model3) + + # Output layer + model4 = [ nn.ReflectionPad2d(3), + nn.Conv2d(64, output_nc, 7)] + if sigmoid: + model4 += [nn.Sigmoid()] + + self.model4 = nn.Sequential(*model4) + + def forward(self, x, cond=None): + out = self.model0(x) + out = self.model1(out) + out = self.model2(out) + out = self.model3(out) + out = self.model4(out) + + return out + + +class LineartDetector: + def __init__(self, model, coarse_model): + self.model = model + self.model_coarse = coarse_model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, coarse_filename=None, cache_dir=None, local_files_only=False): + filename = filename or "sk_model.pth" + coarse_filename = coarse_filename or "sk_model2.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + coarse_model_path = os.path.join(pretrained_model_or_path, coarse_filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + coarse_model_path = hf_hub_download(pretrained_model_or_path, coarse_filename, cache_dir=cache_dir, local_files_only=local_files_only) + + model = Generator(3, 1, 3) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) + model.eval() + + coarse_model = Generator(3, 1, 3) + coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu'))) + coarse_model.eval() + + return cls(model, coarse_model) + + def to(self, device): + self.model.to(device) + self.model_coarse.to(device) + return self + + def __call__(self, input_image, coarse=False, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + model = self.model_coarse if coarse else self.model + assert input_image.ndim == 3 + image = input_image + with torch.no_grad(): + image = torch.from_numpy(image).float().to(device) + image = image / 255.0 + image = rearrange(image, 'h w c -> 1 c h w') + line = model(image)[0][0] + + line = line.cpu().numpy() + line = (line * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = line + + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + detected_map = 255 - detected_map + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/lineart_anime/LICENSE b/controlnet_aux/lineart_anime/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520 --- /dev/null +++ b/controlnet_aux/lineart_anime/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/controlnet_aux/lineart_anime/__init__.py b/controlnet_aux/lineart_anime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87b97d91842b8ec4e86a450ac61673282949665e --- /dev/null +++ b/controlnet_aux/lineart_anime/__init__.py @@ -0,0 +1,189 @@ +import functools +import os +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class LineartAnimeDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): + filename = filename or "netG.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False) + ckpt = torch.load(model_path) + for key in list(ckpt.keys()): + if 'module.' in key: + ckpt[key.replace('module.', '')] = ckpt[key] + del ckpt[key] + net.load_state_dict(ckpt) + net.eval() + + return cls(net) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + H, W, C = input_image.shape + Hn = 256 * int(np.ceil(float(H) / 256.0)) + Wn = 256 * int(np.ceil(float(W) / 256.0)) + img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC) + with torch.no_grad(): + image_feed = torch.from_numpy(img).float().to(device) + image_feed = image_feed / 127.5 - 1.0 + image_feed = rearrange(image_feed, 'h w c -> 1 c h w') + + line = self.model(image_feed)[0, 0] * 127.5 + 127.5 + line = line.cpu().numpy() + + line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC) + line = line.clip(0, 255).astype(np.uint8) + + detected_map = line + + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + detected_map = 255 - detected_map + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/lineart_standard/__init__.py b/controlnet_aux/lineart_standard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..48c6c7042d4a64f2b69135e59087fc26e6329313 --- /dev/null +++ b/controlnet_aux/lineart_standard/__init__.py @@ -0,0 +1,47 @@ +# Code based based from the repository comfyui_controlnet_aux: +# https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/controlnet_aux/lineart_standard/__init__.py +import cv2 +import numpy as np +from PIL import Image + +from ..util import HWC3, resize_image + + +class LineartStandardDetector: + def __call__( + self, + input_image=None, + guassian_sigma=6.0, + intensity_threshold=8, + detect_resolution=512, + output_type="pil", + ): + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + else: + output_type = output_type or "np" + + original_height, original_width, _ = input_image.shape + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + x = input_image.astype(np.float32) + g = cv2.GaussianBlur(x, (0, 0), guassian_sigma) + intensity = np.min(g - x, axis=2).clip(0, 255) + intensity /= max(16, np.median(intensity[intensity > intensity_threshold])) + intensity *= 127 + detected_map = intensity.clip(0, 255).astype(np.uint8) + + detected_map = HWC3(detected_map) + + detected_map = cv2.resize( + detected_map, + (original_width, original_height), + interpolation=cv2.INTER_CUBIC, + ) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/mediapipe_face/__init__.py b/controlnet_aux/mediapipe_face/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91f3cfc66832cb6acfc673c063cdc1b09496ff39 --- /dev/null +++ b/controlnet_aux/mediapipe_face/__init__.py @@ -0,0 +1,53 @@ +import warnings +from typing import Union + +import cv2 +import numpy as np +from PIL import Image + +from ..util import HWC3, resize_image +from .mediapipe_face_common import generate_annotation + + +class MediapipeFaceDetector: + def __call__(self, + input_image: Union[np.ndarray, Image.Image] = None, + max_faces: int = 1, + min_confidence: float = 0.5, + output_type: str = "pil", + detect_resolution: int = 512, + image_resolution: int = 512, + **kwargs): + + if "image" in kwargs: + warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("image") + if input_image is None: + raise ValueError("input_image must be defined.") + + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + detected_map = generate_annotation(input_image, max_faces, min_confidence) + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/mediapipe_face/mediapipe_face_common.py b/controlnet_aux/mediapipe_face/mediapipe_face_common.py new file mode 100644 index 0000000000000000000000000000000000000000..76f6d32c6d8a5b561e0f10e77d193eff363ef0ba --- /dev/null +++ b/controlnet_aux/mediapipe_face/mediapipe_face_common.py @@ -0,0 +1,164 @@ +from typing import Mapping +import warnings + +try: + import mediapipe as mp +except ImportError: + warnings.warn( + "The module 'mediapipe' is not installed. The package will have limited functionality. Please install it using the command: pip install 'mediapipe'" + ) + + mp = None + +import numpy + +if mp: + mp_drawing = mp.solutions.drawing_utils + mp_drawing_styles = mp.solutions.drawing_styles + mp_face_detection = mp.solutions.face_detection # Only for counting faces. + mp_face_mesh = mp.solutions.face_mesh + mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION + mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS + mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS + + DrawingSpec = mp.solutions.drawing_styles.DrawingSpec + PoseLandmark = mp.solutions.drawing_styles.PoseLandmark + + min_face_size_pixels: int = 64 + f_thick = 2 + f_rad = 1 + right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) + right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) + right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) + left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) + left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) + left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) + mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad) + head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) + + # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. + face_connection_spec = {} + for edge in mp_face_mesh.FACEMESH_FACE_OVAL: + face_connection_spec[edge] = head_draw + for edge in mp_face_mesh.FACEMESH_LEFT_EYE: + face_connection_spec[edge] = left_eye_draw + for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: + face_connection_spec[edge] = left_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: + # face_connection_spec[edge] = left_iris_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: + face_connection_spec[edge] = right_eye_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: + face_connection_spec[edge] = right_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: + # face_connection_spec[edge] = right_iris_draw + for edge in mp_face_mesh.FACEMESH_LIPS: + face_connection_spec[edge] = mouth_draw + iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} + + +def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2): + """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all + landmarks. Until our PR is merged into mediapipe, we need this separate method.""" + if len(image.shape) != 3: + raise ValueError("Input image must be H,W,C.") + image_rows, image_cols, image_channels = image.shape + if image_channels != 3: # BGR channels + raise ValueError('Input image must contain three channel bgr data.') + for idx, landmark in enumerate(landmark_list.landmark): + if ( + (landmark.HasField('visibility') and landmark.visibility < 0.9) or + (landmark.HasField('presence') and landmark.presence < 0.5) + ): + continue + if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: + continue + image_x = int(image_cols*landmark.x) + image_y = int(image_rows*landmark.y) + draw_color = None + if isinstance(drawing_spec, Mapping): + if drawing_spec.get(idx) is None: + continue + else: + draw_color = drawing_spec[idx].color + elif isinstance(drawing_spec, DrawingSpec): + draw_color = drawing_spec.color + image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color + + +def reverse_channels(image): + """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB.""" + # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order. + # im[:,:,::[2,1,0]] would also work but makes a copy of the data. + return image[:, :, ::-1] + + +def generate_annotation( + img_rgb, + max_faces: int, + min_confidence: float +): + """ + Find up to 'max_faces' inside the provided input image. + If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many + pixels in the image. + """ + with mp_face_mesh.FaceMesh( + static_image_mode=True, + max_num_faces=max_faces, + refine_landmarks=True, + min_detection_confidence=min_confidence, + ) as facemesh: + img_height, img_width, img_channels = img_rgb.shape + assert(img_channels == 3) + + results = facemesh.process(img_rgb).multi_face_landmarks + + if results is None: + print("No faces detected in controlnet image for Mediapipe face annotator.") + return numpy.zeros_like(img_rgb) + + # Filter faces that are too small + filtered_landmarks = [] + for lm in results: + landmarks = lm.landmark + face_rect = [ + landmarks[0].x, + landmarks[0].y, + landmarks[0].x, + landmarks[0].y, + ] # Left, up, right, down. + for i in range(len(landmarks)): + face_rect[0] = min(face_rect[0], landmarks[i].x) + face_rect[1] = min(face_rect[1], landmarks[i].y) + face_rect[2] = max(face_rect[2], landmarks[i].x) + face_rect[3] = max(face_rect[3], landmarks[i].y) + if min_face_size_pixels > 0: + face_width = abs(face_rect[2] - face_rect[0]) + face_height = abs(face_rect[3] - face_rect[1]) + face_width_pixels = face_width * img_width + face_height_pixels = face_height * img_height + face_size = min(face_width_pixels, face_height_pixels) + if face_size >= min_face_size_pixels: + filtered_landmarks.append(lm) + else: + filtered_landmarks.append(lm) + + # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start. + empty = numpy.zeros_like(img_rgb) + + # Draw detected faces: + for face_landmarks in filtered_landmarks: + mp_drawing.draw_landmarks( + empty, + face_landmarks, + connections=face_connection_spec.keys(), + landmark_drawing_spec=None, + connection_drawing_spec=face_connection_spec + ) + draw_pupils(empty, face_landmarks, iris_landmark_spec, 2) + + # Flip BGR back to RGB. + empty = reverse_channels(empty).copy() + + return empty \ No newline at end of file diff --git a/controlnet_aux/midas/LICENSE b/controlnet_aux/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/controlnet_aux/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/controlnet_aux/midas/__init__.py b/controlnet_aux/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfab34112e65aeda9848a461fc4a1cc4d2107cf --- /dev/null +++ b/controlnet_aux/midas/__init__.py @@ -0,0 +1,95 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, model_type="dpt_hybrid", filename=None, cache_dir=None, local_files_only=False): + if pretrained_model_or_path == "lllyasviel/ControlNet": + filename = filename or "annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + else: + filename = filename or "dpt_hybrid-midas-501f0c75.pt" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + model = MiDaSInference(model_type=model_type, model_path=model_path) + + return cls(model) + + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, image_resolution=512, output_type=None): + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float() + image_depth = image_depth.to(device) + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + if depth_and_normal: + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1] + + depth_image = HWC3(depth_image) + if depth_and_normal: + normal_image = HWC3(normal_image) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + depth_image = cv2.resize(depth_image, (W, H), interpolation=cv2.INTER_LINEAR) + if depth_and_normal: + normal_image = cv2.resize(normal_image, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + depth_image = Image.fromarray(depth_image) + if depth_and_normal: + normal_image = Image.fromarray(normal_image) + + if depth_and_normal: + return depth_image, normal_image + else: + return depth_image diff --git a/controlnet_aux/midas/api.py b/controlnet_aux/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4cb4d6b3edb344e5d566da7f90037d163b5f21 --- /dev/null +++ b/controlnet_aux/midas/api.py @@ -0,0 +1,169 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from ..util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type, model_path=None): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = model_path or ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type, model_path): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type, model_path) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/controlnet_aux/midas/midas/__init__.py b/controlnet_aux/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/midas/midas/base_model.py b/controlnet_aux/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/controlnet_aux/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/controlnet_aux/midas/midas/blocks.py b/controlnet_aux/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/controlnet_aux/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/controlnet_aux/midas/midas/dpt_depth.py b/controlnet_aux/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/controlnet_aux/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/controlnet_aux/midas/midas/midas_net.py b/controlnet_aux/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/controlnet_aux/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/controlnet_aux/midas/midas/midas_net_custom.py b/controlnet_aux/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/controlnet_aux/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/controlnet_aux/midas/midas/transforms.py b/controlnet_aux/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/controlnet_aux/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/controlnet_aux/midas/midas/vit.py b/controlnet_aux/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/controlnet_aux/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/controlnet_aux/midas/utils.py b/controlnet_aux/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/controlnet_aux/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/controlnet_aux/mlsd/LICENSE b/controlnet_aux/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/controlnet_aux/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/controlnet_aux/mlsd/__init__.py b/controlnet_aux/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb7e650ee072648b449f874c071e90d96086664 --- /dev/null +++ b/controlnet_aux/mlsd/__init__.py @@ -0,0 +1,79 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + + +class MLSDdetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): + if pretrained_model_or_path == "lllyasviel/ControlNet": + filename = filename or "annotator/ckpts/mlsd_large_512_fp32.pth" + else: + filename = filename or "mlsd_large_512_fp32.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + + detected_map = img_output[:, :, 0] + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/mlsd/models/__init__.py b/controlnet_aux/mlsd/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/mlsd/models/mbv2_mlsd_large.py b/controlnet_aux/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/controlnet_aux/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py b/controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) + + return x \ No newline at end of file diff --git a/controlnet_aux/mlsd/utils.py b/controlnet_aux/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28071cbf129a2bedb21a44f95d565aef7974e583 --- /dev/null +++ b/controlnet_aux/mlsd/utils.py @@ -0,0 +1,584 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + + device = next(iter(model.parameters())).device + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float() + batch_image = batch_image.to(device) + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + device = next(iter(model.parameters())).device + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to(device) + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... + ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/controlnet_aux/normalbae/LICENSE b/controlnet_aux/normalbae/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520 --- /dev/null +++ b/controlnet_aux/normalbae/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/controlnet_aux/normalbae/__init__.py b/controlnet_aux/normalbae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6011c3be08b418fd85481c1c0b8ce11987df0843 --- /dev/null +++ b/controlnet_aux/normalbae/__init__.py @@ -0,0 +1,109 @@ +import os +import types +import warnings + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from .nets.NNET import NNET + + +# load model +def load_checkpoint(fpath, model): + ckpt = torch.load(fpath, map_location='cpu')['model'] + + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + model.load_state_dict(load_dict) + return model + +class NormalBaeDetector: + def __init__(self, model): + self.model = model + self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): + filename = filename or "scannet.pt" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + args = types.SimpleNamespace() + args.mode = 'client' + args.architecture = 'BN' + args.pretrained = 'scannet' + args.sampling_ratio = 0.4 + args.importance_ratio = 0.7 + model = NNET(args) + model = load_checkpoint(model_path, model) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + image_normal = input_image + with torch.no_grad(): + image_normal = torch.from_numpy(image_normal).float().to(device) + image_normal = image_normal / 255.0 + image_normal = rearrange(image_normal, 'h w c -> 1 c h w') + image_normal = self.norm(image_normal) + + normal = self.model(image_normal) + normal = normal[0][-1][:, :3] + # d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5 + # d = torch.maximum(d, torch.ones_like(d) * 1e-5) + # normal /= d + normal = ((normal + 1) * 0.5).clip(0, 1) + + normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy() + normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = normal_image + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + \ No newline at end of file diff --git a/controlnet_aux/normalbae/nets/NNET.py b/controlnet_aux/normalbae/nets/NNET.py new file mode 100644 index 0000000000000000000000000000000000000000..3ddbc50c3ac18aa4b7f16779fe3c0133981ecc7a --- /dev/null +++ b/controlnet_aux/normalbae/nets/NNET.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .submodules.encoder import Encoder +from .submodules.decoder import Decoder + + +class NNET(nn.Module): + def __init__(self, args): + super(NNET, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder(args) + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + return self.decoder.parameters() + + def forward(self, img, **kwargs): + return self.decoder(self.encoder(img), **kwargs) \ No newline at end of file diff --git a/controlnet_aux/normalbae/nets/__init__.py b/controlnet_aux/normalbae/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/normalbae/nets/baseline.py b/controlnet_aux/normalbae/nets/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..602d0fbdac1acc9ede9bc1f2e10a5df78831ce9d --- /dev/null +++ b/controlnet_aux/normalbae/nets/baseline.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .submodules.submodules import UpSampleBN, norm_normalize + + +# This is the baseline encoder-decoder we used in the ablation study +class NNET(nn.Module): + def __init__(self, args=None): + super(NNET, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder(num_classes=4) + + def forward(self, x, **kwargs): + out = self.decoder(self.encoder(x), **kwargs) + + # Bilinearly upsample the output to match the input resolution + up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False) + + # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa) + up_out = norm_normalize(up_out) + return up_out + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + modules = [self.decoder] + for m in modules: + yield from m.parameters() + + +# Encoder +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b5_ap' + basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True) + + # Remove last layer + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +# Decoder (no pixel-wise MLP, no uncertainty-guided sampling) +class Decoder(nn.Module): + def __init__(self, num_classes=4): + super(Decoder, self).__init__() + self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) + self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) + self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1) + + def forward(self, features): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + x_d0 = self.conv2(x_block4) + x_d1 = self.up1(x_d0, x_block3) + x_d2 = self.up2(x_d1, x_block2) + x_d3 = self.up3(x_d2, x_block1) + x_d4 = self.up4(x_d3, x_block0) + out = self.conv3(x_d4) + return out + + +if __name__ == '__main__': + model = Baseline() + x = torch.rand(2, 3, 480, 640) + out = model(x) + print(out.shape) diff --git a/controlnet_aux/normalbae/nets/submodules/__init__.py b/controlnet_aux/normalbae/nets/submodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/normalbae/nets/submodules/decoder.py b/controlnet_aux/normalbae/nets/submodules/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..993203d1792311f1c492091eaea3c1ac9088187f --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/decoder.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points + + +class Decoder(nn.Module): + def __init__(self, args): + super(Decoder, self).__init__() + + # hyper-parameter for sampling + self.sampling_ratio = args.sampling_ratio + self.importance_ratio = args.importance_ratio + + # feature-map + self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) + if args.architecture == 'BN': + self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) + + elif args.architecture == 'GN': + self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128) + + else: + raise Exception('invalid architecture') + + # produces 1/8 res output + self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + # produces 1/4 res output + self.out_conv_res4 = nn.Sequential( + nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + # produces 1/2 res output + self.out_conv_res2 = nn.Sequential( + nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + # produces 1/1 res output + self.out_conv_res1 = nn.Sequential( + nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + def forward(self, features, gt_norm_mask=None, mode='test'): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + + # generate feature-map + + x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res + x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res + x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res + x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res + x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res + + # 1/8 res output + out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output + out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output + + ################################################################################################################ + # out_res4 + ################################################################################################################ + + if mode == 'train': + # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160] + out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res8_res4.shape + + # samples: [B, 1, N, 2] + point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res4 = out_res8_res4 + + # grid_sample feature-map + feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N) + init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N) + samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized + + for i in range(B): + out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + # try all pixels + out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N) + out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized + out_res4 = out_res4.view(B, 4, H, W) + samples_pred_res4 = point_coords_res4 = None + + ################################################################################################################ + # out_res2 + ################################################################################################################ + + if mode == 'train': + + # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320] + out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res4_res2.shape + + # samples: [B, 1, N, 2] + point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res2 = out_res4_res2 + + # grid_sample feature-map + feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N) + init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N) + samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized + + for i in range(B): + out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N) + out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized + out_res2 = out_res2.view(B, 4, H, W) + samples_pred_res2 = point_coords_res2 = None + + ################################################################################################################ + # out_res1 + ################################################################################################################ + + if mode == 'train': + # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320] + out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res2_res1.shape + + # samples: [B, 1, N, 2] + point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res1 = out_res2_res1 + + # grid_sample feature-map + feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N) + init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N) + samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized + + for i in range(B): + out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N) + out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized + out_res1 = out_res1.view(B, 4, H, W) + samples_pred_res1 = point_coords_res1 = None + + return [out_res8, out_res4, out_res2, out_res1], \ + [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \ + [None, point_coords_res4, point_coords_res2, point_coords_res1] + diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/.gitignore b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f04e5fff91094d9b9c662bba977d762bf71516ac --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/.gitignore @@ -0,0 +1,109 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# pytorch stuff +*.pth +*.onnx +*.pb + +trained_models/ +.fuse_hidden* diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md new file mode 100644 index 0000000000000000000000000000000000000000..6ead7171ce5a5bbd2702f6b5c825dc9808ba5658 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md @@ -0,0 +1,555 @@ +# Model Performance Benchmarks + +All benchmarks run as per: + +``` +python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx +python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx +python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3 +python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt +python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb +python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb +``` + +## EfficientNet-B0 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897 +Time per operator type: + 29.7378 ms. 60.5145%. Conv + 12.1785 ms. 24.7824%. Sigmoid + 3.62811 ms. 7.38297%. SpatialBN + 2.98444 ms. 6.07314%. Mul + 0.326902 ms. 0.665225%. AveragePool + 0.197317 ms. 0.401528%. FC + 0.0852877 ms. 0.173555%. Add + 0.0032607 ms. 0.00663532%. Squeeze + 49.1416 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 95.2696%. Conv + 0.0269508 GFLOP. 3.33857%. SpatialBN + 0.00846444 GFLOP. 1.04855%. Mul + 0.002561 GFLOP. 0.317248%. FC + 0.000210112 GFLOP. 0.0260279%. Add + 0.807256 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 43.0891%. Mul + 43.2015 MB. 31.807%. Conv + 27.2869 MB. 20.0899%. SpatialBN + 5.12912 MB. 3.77631%. FC + 1.6809 MB. 1.23756%. Add + 135.824 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 38.1965%. Mul + 26.9881 MB. 30.4465%. Conv + 26.9508 MB. 30.4044%. SpatialBN + 0.840448 MB. 0.948147%. Add + 0.004 MB. 0.00451258%. FC + 88.6412 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 74.9391%. Conv + 5.124 MB. 24.265%. FC + 0.168064 MB. 0.795877%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 21.1168 MB in Total +``` +### Optimized +``` +Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996 +Time per operator type: + 29.776 ms. 65.002%. Conv + 12.2803 ms. 26.8084%. Sigmoid + 3.15073 ms. 6.87815%. Mul + 0.328651 ms. 0.717456%. AveragePool + 0.186237 ms. 0.406563%. FC + 0.0832429 ms. 0.181722%. Add + 0.0026184 ms. 0.00571606%. Squeeze + 45.8078 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 98.5601%. Conv + 0.00846444 GFLOP. 1.08476%. Mul + 0.002561 GFLOP. 0.328205%. FC + 0.000210112 GFLOP. 0.0269269%. Add + 0.780305 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 53.8803%. Mul + 43.2855 MB. 39.8501%. Conv + 5.12912 MB. 4.72204%. FC + 1.6809 MB. 1.54749%. Add + 108.621 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 54.8834%. Mul + 26.9881 MB. 43.7477%. Conv + 0.840448 MB. 1.36237%. Add + 0.004 MB. 0.00648399%. FC + 61.6904 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 75.5403%. Conv + 5.124 MB. 24.4597%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 20.9488 MB in Total +``` + +## EfficientNet-B1 +### Optimized +``` +Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256 +Time per operator type: + 45.7915 ms. 66.3206%. Conv + 17.8718 ms. 25.8841%. Sigmoid + 4.44132 ms. 6.43244%. Mul + 0.51001 ms. 0.738658%. AveragePool + 0.233283 ms. 0.337868%. Add + 0.194986 ms. 0.282402%. FC + 0.00268255 ms. 0.00388519%. Squeeze + 69.0456 ms in Total +FLOP per operator type: + 1.37105 GFLOP. 98.7673%. Conv + 0.0138759 GFLOP. 0.99959%. Mul + 0.002561 GFLOP. 0.184489%. FC + 0.000674432 GFLOP. 0.0485847%. Add + 1.38816 GFLOP in Total +Feature Memory Read per operator type: + 94.624 MB. 54.0789%. Mul + 69.8255 MB. 39.9062%. Conv + 5.39546 MB. 3.08357%. Add + 5.12912 MB. 2.93136%. FC + 174.974 MB in Total +Feature Memory Written per operator type: + 55.5035 MB. 54.555%. Mul + 43.5333 MB. 42.7894%. Conv + 2.69773 MB. 2.65163%. Add + 0.004 MB. 0.00393165%. FC + 101.739 MB in Total +Parameter Memory per operator type: + 25.7479 MB. 83.4024%. Conv + 5.124 MB. 16.5976%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 30.8719 MB in Total +``` + +## EfficientNet-B2 +### Optimized +``` +Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366 +Time per operator type: + 61.4627 ms. 67.5845%. Conv + 22.7458 ms. 25.0113%. Sigmoid + 5.59931 ms. 6.15701%. Mul + 0.642567 ms. 0.706568%. AveragePool + 0.272795 ms. 0.299965%. Add + 0.216178 ms. 0.237709%. FC + 0.00268895 ms. 0.00295677%. Squeeze + 90.942 ms in Total +FLOP per operator type: + 1.98431 GFLOP. 98.9343%. Conv + 0.0177039 GFLOP. 0.882686%. Mul + 0.002817 GFLOP. 0.140451%. FC + 0.000853984 GFLOP. 0.0425782%. Add + 2.00568 GFLOP in Total +Feature Memory Read per operator type: + 120.609 MB. 54.9637%. Mul + 86.3512 MB. 39.3519%. Conv + 6.83187 MB. 3.11341%. Add + 5.64163 MB. 2.571%. FC + 219.433 MB in Total +Feature Memory Written per operator type: + 70.8155 MB. 54.6573%. Mul + 55.3273 MB. 42.7031%. Conv + 3.41594 MB. 2.63651%. Add + 0.004 MB. 0.00308731%. FC + 129.563 MB in Total +Parameter Memory per operator type: + 30.4721 MB. 84.3913%. Conv + 5.636 MB. 15.6087%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 36.1081 MB in Total +``` + +## MixNet-M +### Optimized +``` +Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448 +Time per operator type: + 48.1139 ms. 75.2052%. Conv + 7.1341 ms. 11.1511%. Sigmoid + 2.63706 ms. 4.12189%. SpatialBN + 1.73186 ms. 2.70701%. Mul + 1.38707 ms. 2.16809%. Split + 1.29322 ms. 2.02139%. Concat + 1.00093 ms. 1.56452%. Relu + 0.235309 ms. 0.367803%. Add + 0.221579 ms. 0.346343%. FC + 0.219315 ms. 0.342803%. AveragePool + 0.00250145 ms. 0.00390993%. Squeeze + 63.9768 ms in Total +FLOP per operator type: + 0.675273 GFLOP. 95.5827%. Conv + 0.0221072 GFLOP. 3.12921%. SpatialBN + 0.00538445 GFLOP. 0.762152%. Mul + 0.003073 GFLOP. 0.434973%. FC + 0.000642488 GFLOP. 0.0909421%. Add + 0 GFLOP. 0%. Concat + 0 GFLOP. 0%. Relu + 0.70648 GFLOP in Total +Feature Memory Read per operator type: + 46.8424 MB. 30.502%. Conv + 36.8626 MB. 24.0036%. Mul + 22.3152 MB. 14.5309%. SpatialBN + 22.1074 MB. 14.3955%. Concat + 14.1496 MB. 9.21372%. Relu + 6.15414 MB. 4.00735%. FC + 5.1399 MB. 3.34692%. Add + 153.571 MB in Total +Feature Memory Written per operator type: + 32.7672 MB. 28.4331%. Conv + 22.1072 MB. 19.1831%. Concat + 22.1072 MB. 19.1831%. SpatialBN + 21.5378 MB. 18.689%. Mul + 14.1496 MB. 12.2781%. Relu + 2.56995 MB. 2.23003%. Add + 0.004 MB. 0.00347092%. FC + 115.243 MB in Total +Parameter Memory per operator type: + 13.7059 MB. 68.674%. Conv + 6.148 MB. 30.8049%. FC + 0.104 MB. 0.521097%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Concat + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 19.9579 MB in Total +``` + +## TF MobileNet-V3 Large 1.0 + +### Optimized +``` +Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525 +Time per operator type: + 17.437 ms. 80.0087%. Conv + 1.27662 ms. 5.8577%. Add + 1.12759 ms. 5.17387%. Div + 0.701155 ms. 3.21721%. Mul + 0.562654 ms. 2.58171%. Relu + 0.431144 ms. 1.97828%. Clip + 0.156902 ms. 0.719936%. FC + 0.0996858 ms. 0.457402%. AveragePool + 0.00112455 ms. 0.00515993%. Flatten + 21.7939 ms in Total +FLOP per operator type: + 0.43062 GFLOP. 98.1484%. Conv + 0.002561 GFLOP. 0.583713%. FC + 0.00210867 GFLOP. 0.480616%. Mul + 0.00193868 GFLOP. 0.441871%. Add + 0.00151532 GFLOP. 0.345377%. Div + 0 GFLOP. 0%. Relu + 0.438743 GFLOP in Total +Feature Memory Read per operator type: + 34.7967 MB. 43.9391%. Conv + 14.496 MB. 18.3046%. Mul + 9.44828 MB. 11.9307%. Add + 9.26157 MB. 11.6949%. Relu + 6.0614 MB. 7.65395%. Div + 5.12912 MB. 6.47673%. FC + 79.193 MB in Total +Feature Memory Written per operator type: + 17.6247 MB. 35.8656%. Conv + 9.26157 MB. 18.847%. Relu + 8.43469 MB. 17.1643%. Mul + 7.75472 MB. 15.7806%. Add + 6.06128 MB. 12.3345%. Div + 0.004 MB. 0.00813985%. FC + 49.1409 MB in Total +Parameter Memory per operator type: + 16.6851 MB. 76.5052%. Conv + 5.124 MB. 23.4948%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8091 MB in Total +``` + +## MobileNet-V3 (RW) + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712 +Time per operator type: + 15.9266 ms. 69.2624%. Conv + 2.36551 ms. 10.2873%. SpatialBN + 1.39102 ms. 6.04936%. Add + 1.30327 ms. 5.66773%. Div + 0.737014 ms. 3.20517%. Mul + 0.639697 ms. 2.78195%. Relu + 0.375681 ms. 1.63378%. Clip + 0.153126 ms. 0.665921%. FC + 0.0993787 ms. 0.432184%. AveragePool + 0.0032632 ms. 0.0141912%. Squeeze + 22.9946 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 94.4041%. Conv + 0.0175992 GFLOP. 3.85829%. SpatialBN + 0.002561 GFLOP. 0.561449%. FC + 0.00210961 GFLOP. 0.46249%. Mul + 0.00173891 GFLOP. 0.381223%. Add + 0.00151626 GFLOP. 0.33241%. Div + 0 GFLOP. 0%. Relu + 0.456141 GFLOP in Total +Feature Memory Read per operator type: + 34.7354 MB. 36.4363%. Conv + 17.7944 MB. 18.6658%. SpatialBN + 14.5035 MB. 15.2137%. Mul + 9.25778 MB. 9.71113%. Relu + 7.84641 MB. 8.23064%. Add + 6.06516 MB. 6.36216%. Div + 5.12912 MB. 5.38029%. FC + 95.3317 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 26.7264%. Conv + 17.5992 MB. 26.6878%. SpatialBN + 9.25778 MB. 14.0387%. Relu + 8.43843 MB. 12.7962%. Mul + 6.95565 MB. 10.5477%. Add + 6.06502 MB. 9.19713%. Div + 0.004 MB. 0.00606568%. FC + 65.9447 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.1564%. Conv + 5.124 MB. 23.3979%. FC + 0.0976 MB. 0.445674%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8994 MB in Total + +``` +### Optimized + +``` +Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527 +Time per operator type: + 17.146 ms. 78.8965%. Conv + 1.38453 ms. 6.37084%. Add + 1.30991 ms. 6.02749%. Div + 0.685417 ms. 3.15391%. Mul + 0.532589 ms. 2.45068%. Relu + 0.418263 ms. 1.92461%. Clip + 0.15128 ms. 0.696106%. FC + 0.102065 ms. 0.469648%. AveragePool + 0.0022143 ms. 0.010189%. Squeeze + 21.7323 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 98.1927%. Conv + 0.002561 GFLOP. 0.583981%. FC + 0.00210961 GFLOP. 0.481051%. Mul + 0.00173891 GFLOP. 0.396522%. Add + 0.00151626 GFLOP. 0.34575%. Div + 0 GFLOP. 0%. Relu + 0.438542 GFLOP in Total +Feature Memory Read per operator type: + 34.7842 MB. 44.833%. Conv + 14.5035 MB. 18.6934%. Mul + 9.25778 MB. 11.9323%. Relu + 7.84641 MB. 10.1132%. Add + 6.06516 MB. 7.81733%. Div + 5.12912 MB. 6.61087%. FC + 77.5861 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 36.4556%. Conv + 9.25778 MB. 19.1492%. Relu + 8.43843 MB. 17.4544%. Mul + 6.95565 MB. 14.3874%. Add + 6.06502 MB. 12.5452%. Div + 0.004 MB. 0.00827378%. FC + 48.3455 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.4973%. Conv + 5.124 MB. 23.5027%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8018 MB in Total + +``` + +## MnasNet-A1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345 +Time per operator type: + 24.4656 ms. 79.0905%. Conv + 4.14958 ms. 13.4144%. SpatialBN + 1.60598 ms. 5.19169%. Relu + 0.295219 ms. 0.95436%. Mul + 0.187609 ms. 0.606486%. FC + 0.120556 ms. 0.389724%. AveragePool + 0.09036 ms. 0.292109%. Add + 0.015727 ms. 0.050841%. Sigmoid + 0.00306205 ms. 0.00989875%. Squeeze + 30.9337 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 95.6434%. Conv + 0.0248873 GFLOP. 3.8355%. SpatialBN + 0.002561 GFLOP. 0.394688%. FC + 0.000597408 GFLOP. 0.0920695%. Mul + 0.000222656 GFLOP. 0.0343146%. Add + 0 GFLOP. 0%. Relu + 0.648867 GFLOP in Total +Feature Memory Read per operator type: + 35.5457 MB. 38.4109%. Conv + 25.1552 MB. 27.1829%. SpatialBN + 22.5235 MB. 24.339%. Relu + 5.12912 MB. 5.54256%. FC + 2.40586 MB. 2.59978%. Mul + 1.78125 MB. 1.92483%. Add + 92.5406 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 32.9424%. Conv + 24.8873 MB. 32.92%. SpatialBN + 22.5235 MB. 29.7932%. Relu + 2.38963 MB. 3.16092%. Mul + 0.890624 MB. 1.17809%. Add + 0.004 MB. 0.00529106%. FC + 75.5993 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.1459%. Conv + 5.124 MB. 32.9917%. FC + 0.133952 MB. 0.86247%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.5312 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597 +Time per operator type: + 22.0547 ms. 91.1375%. Conv + 1.49096 ms. 6.16116%. Relu + 0.253417 ms. 1.0472%. Mul + 0.18506 ms. 0.76473%. FC + 0.112942 ms. 0.466717%. AveragePool + 0.086769 ms. 0.358559%. Add + 0.0127889 ms. 0.0528479%. Sigmoid + 0.0027346 ms. 0.0113003%. Squeeze + 24.1994 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 99.4581%. Conv + 0.002561 GFLOP. 0.41043%. FC + 0.000597408 GFLOP. 0.0957417%. Mul + 0.000222656 GFLOP. 0.0356832%. Add + 0 GFLOP. 0%. Relu + 0.623979 GFLOP in Total +Feature Memory Read per operator type: + 35.6127 MB. 52.7968%. Conv + 22.5235 MB. 33.3917%. Relu + 5.12912 MB. 7.60406%. FC + 2.40586 MB. 3.56675%. Mul + 1.78125 MB. 2.64075%. Add + 67.4524 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 49.1092%. Conv + 22.5235 MB. 44.4145%. Relu + 2.38963 MB. 4.71216%. Mul + 0.890624 MB. 1.75624%. Add + 0.004 MB. 0.00788768%. FC + 50.712 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.7213%. Conv + 5.124 MB. 33.2787%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.3972 MB in Total +``` +## MnasNet-B1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322 +Time per operator type: + 29.1121 ms. 83.3081%. Conv + 4.14959 ms. 11.8746%. SpatialBN + 1.35823 ms. 3.88675%. Relu + 0.186188 ms. 0.532802%. FC + 0.116244 ms. 0.332647%. Add + 0.018641 ms. 0.0533437%. AveragePool + 0.0040904 ms. 0.0117052%. Squeeze + 34.9451 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 96.2088%. Conv + 0.0218266 GFLOP. 3.35303%. SpatialBN + 0.002561 GFLOP. 0.393424%. FC + 0.000291648 GFLOP. 0.0448034%. Add + 0 GFLOP. 0%. Relu + 0.650951 GFLOP in Total +Feature Memory Read per operator type: + 34.4354 MB. 41.3788%. Conv + 22.1299 MB. 26.5921%. SpatialBN + 19.1923 MB. 23.0622%. Relu + 5.12912 MB. 6.16333%. FC + 2.33318 MB. 2.80364%. Add + 83.2199 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 34.0955%. Conv + 21.8266 MB. 34.0955%. SpatialBN + 19.1923 MB. 29.9805%. Relu + 1.16659 MB. 1.82234%. Add + 0.004 MB. 0.00624844%. FC + 64.016 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 69.9104%. Conv + 5.124 MB. 29.2245%. FC + 0.15168 MB. 0.865099%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.5332 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426 +Time per operator type: + 24.9888 ms. 94.0962%. Conv + 1.26147 ms. 4.75011%. Relu + 0.176234 ms. 0.663619%. FC + 0.113309 ms. 0.426672%. Add + 0.0138708 ms. 0.0522311%. AveragePool + 0.00295685 ms. 0.0111341%. Squeeze + 26.5566 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 99.5466%. Conv + 0.002561 GFLOP. 0.407074%. FC + 0.000291648 GFLOP. 0.0463578%. Add + 0 GFLOP. 0%. Relu + 0.629124 GFLOP in Total +Feature Memory Read per operator type: + 34.5112 MB. 56.4224%. Conv + 19.1923 MB. 31.3775%. Relu + 5.12912 MB. 8.3856%. FC + 2.33318 MB. 3.81452%. Add + 61.1658 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 51.7346%. Conv + 19.1923 MB. 45.4908%. Relu + 1.16659 MB. 2.76513%. Add + 0.004 MB. 0.00948104%. FC + 42.1895 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 70.5205%. Conv + 5.124 MB. 29.4795%. FC + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.3816 MB in Total +``` diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/LICENSE b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..80e7d15508202f3262a50db27f5198460d7f509f --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Ross Wightman + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/README.md b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..463368280d6a5015060eb73d20fe6512f8e04c50 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/README.md @@ -0,0 +1,323 @@ +# (Generic) EfficientNets for PyTorch + +A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search. + +All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py)) + +## What's New + +### Aug 19, 2020 +* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1) +* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1) +* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX +* ONNX runtime based validation script added +* activations (mostly) brought in sync with `timm` equivalents + + +### April 5, 2020 +* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite + * 3.5M param MobileNet-V2 100 @ 73% + * 4.5M param MobileNet-V2 110d @ 75% + * 6.1M param MobileNet-V2 140 @ 76.5% + * 5.8M param MobileNet-V2 120d @ 77.3% + +### March 23, 2020 + * Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1 + * IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior + +### Feb 12, 2020 + * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) + * Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization. + * Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) + +### Jan 22, 2020 + * Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models) + * Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict + * Test models, torchscript, onnx export with PyTorch 1.4 -- no issues + +### Nov 22, 2019 + * New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different + preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights. + +### Nov 15, 2019 + * Ported official TF MobileNet-V3 float32 large/small/minimalistic weights + * Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine + +### Oct 30, 2019 + * Many of the models will now work with torch.jit.script, MixNet being the biggest exception + * Improved interface for enabling torchscript or ONNX export compatible modes (via config) + * Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn + * Activation factory to select best version of activation by name or override one globally + * Add pretrained checkpoint load helper that handles input conv and classifier changes + +### Oct 27, 2019 + * Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + * Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + * Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base + * Switch activations and global pooling to modules + * Add memory-efficient Swish/Mish impl + * Add as_sequential() method to all models and allow as an argument in entrypoint fns + * Move MobileNetV3 into own file since it has a different head + * Remove ChamNet, MobileNet V2/V1 since they will likely never be used here + +## Models + +Implemented models include: + * EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252) + * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) + * EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946) + * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) + * EfficientNet-CondConv (https://arxiv.org/abs/1904.04971) + * EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * MixNet (https://arxiv.org/abs/1907.09595) + * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) + * MobileNet-V3 (https://arxiv.org/abs/1905.02244) + * FBNet-C (https://arxiv.org/abs/1812.03443) + * Single-Path NAS (https://arxiv.org/abs/1904.02877) + +I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code. + +## Pretrained + +I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models + + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop | +|---|---|---|---|---|---|---|---| +| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 | +| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 | +| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 | +| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 | +| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 | +| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 | +| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 | +| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 | +| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 | +| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 | +| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 | +| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 | +| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 | +| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 | +| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 | +| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 | +| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 | +| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 | + + +More pretrained models to come... + + +## Ported Weights + +The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args. + +**IMPORTANT:** +* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std. +* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl. + +To run validation for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic` + +To run validation w/ TF preprocessing for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing` + +To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5` + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop | +|---|---|---|---|---|---|---| +| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A | +| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 | +| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 | +| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A | +| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 | +| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 | +| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 | +| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A | +| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 | +| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A | +| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 | +| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A | +| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A | +| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A | +| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A | +| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 | +| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 | +| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 | +| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A | +| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 | +| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 | +| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A | +| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 | +| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A | +| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 | +| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A | +| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 | +| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A | +| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A | +| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 | + + +*tfp models validated with `tf-preprocessing` pipeline + +Google tf and tflite weights ported from official Tensorflow repositories +* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet +* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet +* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet + +## Usage + +### Environment + +All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x. + +Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself. + +PyTorch versions 1.4, 1.5, 1.6 have been tested with this code. + +I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: +``` +conda create -n torch-env +conda activate torch-env +conda install -c pytorch pytorch torchvision cudatoolkit=10.2 +``` + +### PyTorch Hub + +Models can be accessed via the PyTorch Hub API + +``` +>>> torch.hub.list('rwightman/gen-efficientnet-pytorch') +['efficientnet_b0', ...] +>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True) +>>> model.eval() +>>> output = model(torch.randn(1,3,224,224)) +``` + +### Pip +This package can be installed via pip. + +Install (after conda env/install): +``` +pip install geffnet +``` + +Eval use: +``` +>>> import geffnet +>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True) +>>> m.eval() +``` + +Train use: +``` +>>> import geffnet +>>> # models can also be created by using the entrypoint directly +>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2) +>>> m.train() +``` + +Create in a nn.Sequential container, for fast.ai, etc: +``` +>>> import geffnet +>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True) +``` + +### Exporting + +Scripts are included to +* export models to ONNX (`onnx_export.py`) +* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg) +* validate with ONNX runtime (`onnx_validate.py`) +* convert ONNX model to Caffe2 (`onnx_to_caffe.py`) +* validate in Caffe2 (`caffe2_validate.py`) +* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`) + +As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation: +``` +python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx +python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx +``` + +These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible +export now requires additional args mentioned in the export script (not needed in earlier versions). + +#### Export Notes +1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script. +2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working. +3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization. +3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here. + + diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/__init__.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..93f28a1e63d9f7287ca02997c7991fe66dd0aeb9 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py @@ -0,0 +1,65 @@ +""" Caffe2 validation script + +This script runs Caffe2 benchmark on exported ONNX model. +It is a useful tool for reporting model FLOPS. + +Copyright 2020 Ross Wightman +""" +import argparse +from caffe2.python import core, workspace, model_helper +from caffe2.proto import caffe2_pb2 + + +parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark') +parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', + help='caffe2 model pb name prefix') +parser.add_argument('--c2-init', default='', type=str, metavar='PATH', + help='caffe2 model init .pb') +parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', + help='caffe2 model predict .pb') +parser.add_argument('-b', '--batch-size', default=1, type=int, + metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=224, type=int, + metavar='N', help='Input image dimension, uses model default if empty') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + if args.c2_prefix: + args.c2_init = args.c2_prefix + '.init.pb' + args.c2_predict = args.c2_prefix + '.predict.pb' + + model = model_helper.ModelHelper(name="le_net", init_params=False) + + # Bring in the init net from init_net.pb + init_net_proto = caffe2_pb2.NetDef() + with open(args.c2_init, "rb") as f: + init_net_proto.ParseFromString(f.read()) + model.param_init_net = core.Net(init_net_proto) + + # bring in the predict net from predict_net.pb + predict_net_proto = caffe2_pb2.NetDef() + with open(args.c2_predict, "rb") as f: + predict_net_proto.ParseFromString(f.read()) + model.net = core.Net(predict_net_proto) + + # CUDA performance not impressive + #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) + #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + + input_blob = model.net.external_inputs[0] + model.param_init_net.GaussianFill( + [], + input_blob.GetUnscopedName(), + shape=(args.batch_size, 3, args.img_size, args.img_size), + mean=0.0, + std=1.0) + workspace.RunNetOnce(model.param_init_net) + workspace.CreateNet(model.net, overwrite=True) + workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True) + + +if __name__ == '__main__': + main() diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfaab38c095663fe32e4addbdf06b57bcb53614 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py @@ -0,0 +1,138 @@ +""" Caffe2 validation script + +This script is created to verify exported ONNX models running in Caffe2 +It utilizes the same PyTorch dataloader/processing pipeline for a +fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +from caffe2.python import core, workspace, model_helper +from caffe2.proto import caffe2_pb2 +from data import create_loader, resolve_data_config, Dataset +from utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', + help='caffe2 model pb name prefix') +parser.add_argument('--c2-init', default='', type=str, metavar='PATH', + help='caffe2 model init .pb') +parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', + help='caffe2 model predict .pb') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + if args.c2_prefix: + args.c2_init = args.c2_prefix + '.init.pb' + args.c2_predict = args.c2_prefix + '.predict.pb' + + model = model_helper.ModelHelper(name="validation_net", init_params=False) + + # Bring in the init net from init_net.pb + init_net_proto = caffe2_pb2.NetDef() + with open(args.c2_init, "rb") as f: + init_net_proto.ParseFromString(f.read()) + model.param_init_net = core.Net(init_net_proto) + + # bring in the predict net from predict_net.pb + predict_net_proto = caffe2_pb2.NetDef() + with open(args.c2_predict, "rb") as f: + predict_net_proto.ParseFromString(f.read()) + model.net = core.Net(predict_net_proto) + + data_config = resolve_data_config(None, args) + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + # this is so obvious, wonderful interface + input_blob = model.net.external_inputs[0] + output_blob = model.net.external_outputs[0] + + if True: + device_opts = None + else: + # CUDA is crashing, no idea why, awesome error message, give it a try for kicks + device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) + model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + + model.param_init_net.GaussianFill( + [], input_blob.GetUnscopedName(), + shape=(1,) + data_config['input_size'], mean=0.0, std=1.0) + workspace.RunNetOnce(model.param_init_net) + workspace.CreateNet(model.net, overwrite=True) + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + caffe2_in = input.data.numpy() + workspace.FeedBlob(input_blob, caffe2_in, device_opts) + workspace.RunNet(model.net, num_iter=1) + output = workspace.FetchBlob(output_blob) + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output.data, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, + ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e441a5838d1e972823b9668ac8d459445f6f6ce --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py @@ -0,0 +1,5 @@ +from .gen_efficientnet import * +from .mobilenetv3 import * +from .model_factory import create_model +from .config import is_exportable, is_scriptable, set_exportable, set_scriptable +from .activations import * \ No newline at end of file diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..813421a743ffc33b8eb53ebf62dd4a03d831b654 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py @@ -0,0 +1,137 @@ +from geffnet import config +from geffnet.activations.activations_me import * +from geffnet.activations.activations_jit import * +from geffnet.activations.activations import * +import torch + +_has_silu = 'silu' in dir(torch.nn.functional) + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=mish_jit, +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=mish_me, + hard_swish=hard_swish_me, + hard_sigmoid_jit=hard_sigmoid_me, +) + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=MishJit, +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=MishMe, + hard_swish=HardSwishMe, + hard_sigmoid=HardSigmoidMe +) + +_OVERRIDE_FN = dict() +_OVERRIDE_LAYER = dict() + + +def add_override_act_fn(name, fn): + global _OVERRIDE_FN + _OVERRIDE_FN[name] = fn + + +def update_override_act_fn(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_FN + _OVERRIDE_FN.update(overrides) + + +def clear_override_act_fn(): + global _OVERRIDE_FN + _OVERRIDE_FN = dict() + + +def add_override_act_layer(name, fn): + _OVERRIDE_LAYER[name] = fn + + +def update_override_act_layer(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_LAYER + _OVERRIDE_LAYER.update(overrides) + + +def clear_override_act_layer(): + global _OVERRIDE_LAYER + _OVERRIDE_LAYER = dict() + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_FN: + return _OVERRIDE_FN[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_FN_ME: + # If not exporting or scripting the model, first look for a memory optimized version + # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin + return _ACT_FN_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_LAYER: + return _OVERRIDE_LAYER[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..bdea692d1397673b2513d898c33edbcb37d94240 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py @@ -0,0 +1,102 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return mish(x, self.inplace) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..7176b05e779787528a47f20d55d64d4a0f219360 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py @@ -0,0 +1,79 @@ +""" Activations (jit) + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + +__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', + 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..e91df5a50fdbe40bc386e2541a4fda743ad95e9a --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py @@ -0,0 +1,174 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', + 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + + Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..27d5307fd9ee0246f1e35f41520f17385d23f1dd --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py @@ -0,0 +1,123 @@ +""" Global layer config state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def layer_config_kwargs(kwargs): + """ Consume config kwargs and return contextmgr obj """ + return set_layer_config( + scriptable=kwargs.pop('scriptable', None), + exportable=kwargs.pop('exportable', None), + no_jit=kwargs.pop('no_jit', None)) diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d8467460c4b36e54c83ce2dcd3ebe91d3432cad2 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py @@ -0,0 +1,304 @@ +""" Conv2D w/ SAME padding, CondConv, MixedConv + +A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and +MobileNetV3 models that maintain weight compatibility with original Tensorflow models. + +Copyright 2020 Ross Wightman +""" +import collections.abc +import math +from functools import partial +from itertools import repeat +from typing import Tuple, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import * + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _same_pad_arg(input_size, kernel_size, stride, dilation): + ih, iw = input_size + kh, kw = kernel_size + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dSameExport(nn.Conv2d): + """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions + + NOTE: This does not currently work with torch.jit.script + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSameExport, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.pad = None + self.pad_input_size = (0, 0) + + def forward(self, x): + input_size = x.size()[-2:] + if self.pad is None: + pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) + self.pad = nn.ZeroPad2d(pad_arg) + self.pad_input_size = input_size + + if self.pad is not None: + x = self.pad(x) + return F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic padding + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + if is_exportable(): + assert not is_scriptable() + return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) + else: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..95dd63d400e70d70664c5a433a2772363f865e61 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py @@ -0,0 +1,683 @@ +""" EfficientNet / MobileNetV3 Blocks and Builder + +Copyright 2020 Ross Wightman +""" +import re +from copy import deepcopy + +from .conv2d_layers import * +from geffnet.activations import * + +__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', + 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', + 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', + 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' +] + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +# +# PyTorch defaults are momentum = .1, eps = 1e-5 +# +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, # None == use containing block's activation layer + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def make_divisible(v: int, divisor: int = 8, min_value: int = None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: # ensure round down does not go down by more than 10%. + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class SqueezeExcite(nn.Module): + + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): + super(SqueezeExcite, self).__init__() + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion + factor of 1.0. This is an alternative to having a IR with optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs: int = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() # for jit.script compat + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Expansion convolution + self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) + self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class EfficientNetBuilder: + """ Build Trunk Blocks for Efficient/Mobile Networks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + + # updated during build + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = EdgeResidual(**ba) + elif bt == 'cn': + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 # incr global idx (across all stacks) + return nn.Sequential(*blocks) + + def __call__(self, in_chs, block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for stack_idx, stack in enumerate(block_args): + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return blocks + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +def initialize_weight_goog(m, n='', fix_group_fanout=True): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def initialize_weight_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cd170d4cc5bed6ca82b61539902b470d3320c691 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py @@ -0,0 +1,1450 @@ +""" Generic Efficient Networks + +A generic MobileNet class with building blocks to support a variety of models: + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* EfficientNet-Lite + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .config import layer_config_kwargs, is_scriptable +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', + 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', + 'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d', + 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', + 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', + 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', + 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', + 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', + 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', + 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', + 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', + 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', + 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', + 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', + 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', + 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', + 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', + 'tf_efficientnet_lite4', + 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] + + +model_urls = { + 'mnasnet_050': None, + 'mnasnet_075': None, + 'mnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + 'mnasnet_140': None, + 'mnasnet_small': None, + + 'semnasnet_050': None, + 'semnasnet_075': None, + 'semnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + 'semnasnet_140': None, + + 'mobilenetv2_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + 'mobilenetv2_110d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + 'mobilenetv2_120d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + 'mobilenetv2_140': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + + 'fbnetc_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + 'spnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + + 'efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + 'efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + 'efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + 'efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + 'efficientnet_b4': None, + 'efficientnet_b5': None, + 'efficientnet_b6': None, + 'efficientnet_b7': None, + 'efficientnet_b8': None, + 'efficientnet_l2': None, + + 'efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + 'efficientnet_em': None, + 'efficientnet_el': None, + + 'efficientnet_cc_b0_4e': None, + 'efficientnet_cc_b0_8e': None, + 'efficientnet_cc_b1_8e': None, + + 'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + 'efficientnet_lite1': None, + 'efficientnet_lite2': None, + 'efficientnet_lite3': None, + 'efficientnet_lite4': None, + + 'tf_efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + 'tf_efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + 'tf_efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + 'tf_efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + 'tf_efficientnet_b4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + 'tf_efficientnet_b5': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + 'tf_efficientnet_b6': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + 'tf_efficientnet_b7': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + 'tf_efficientnet_b8': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + + 'tf_efficientnet_b0_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + 'tf_efficientnet_b1_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + 'tf_efficientnet_b2_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + 'tf_efficientnet_b3_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + 'tf_efficientnet_b4_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + 'tf_efficientnet_b5_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + 'tf_efficientnet_b6_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + 'tf_efficientnet_b7_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + 'tf_efficientnet_b8_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + + 'tf_efficientnet_b0_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + 'tf_efficientnet_b1_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + 'tf_efficientnet_b2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + 'tf_efficientnet_b3_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + 'tf_efficientnet_b4_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + 'tf_efficientnet_b5_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + 'tf_efficientnet_b6_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + 'tf_efficientnet_b7_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + 'tf_efficientnet_l2_ns_475': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + 'tf_efficientnet_l2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + + 'tf_efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + 'tf_efficientnet_em': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + 'tf_efficientnet_el': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + + 'tf_efficientnet_cc_b0_4e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + 'tf_efficientnet_cc_b0_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + 'tf_efficientnet_cc_b1_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + + 'tf_efficientnet_lite0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + 'tf_efficientnet_lite1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + 'tf_efficientnet_lite2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + 'tf_efficientnet_lite3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + 'tf_efficientnet_lite4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + + 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + + 'tf_mixnet_s': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + 'tf_mixnet_m': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + 'tf_mixnet_l': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', +} + + +class GenEfficientNet(nn.Module): + """ Generic EfficientNets + + An implementation of mobile optimized networks that covers: + * EfficientNet (B0-B8, L2, CondConv, EdgeTPU) + * MixNet (Small, Medium, and Large, XL) + * MNASNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + weight_init='goog'): + super(GenEfficientNet, self).__init__() + self.drop_rate = drop_rate + + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, + pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type) + self.bn2 = norm_layer(num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(num_features, num_classes) + + for n, m in self.named_modules(): + if weight_init == 'goog': + initialize_weight_goog(m, n) + else: + initialize_weight_default(m, n) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.conv_head, self.bn2, self.act2, + self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.features(x) + x = self.global_pool(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = GenEfficientNet(**model_kwargs) + if pretrained: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + fix_stem=fix_stem_head, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU6, + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an efficientnet-condconv model.""" + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=nn.ReLU6, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + return mnasnet_100(pretrained, **kwargs) + + +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2. """ + # NOTE for train, drop_rate should be 0.5 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + # NOTE for train set drop_rate=0.2 + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f83a07d690c7ad681c777c19b1e7a5bb95da007 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py @@ -0,0 +1,71 @@ +""" Checkpoint loading / state_dict helpers +Copyright 2020 Ross Wightman +""" +import torch +import os +from collections import OrderedDict +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +def load_checkpoint(model, checkpoint_path): + if checkpoint_path and os.path.isfile(checkpoint_path): + print("=> Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint) + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + else: + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_pretrained(model, url, filter_fn=None, strict=True): + if not url: + print("=> Warning: Pretrained model URL is empty, using random initialization.") + return + + state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') + + input_conv = 'conv_stem' + classifier = 'classifier' + in_chans = getattr(model, input_conv).weight.shape[1] + num_classes = getattr(model, classifier).weight.shape[0] + + input_conv_weight = input_conv + '.weight' + pretrained_in_chans = state_dict[input_conv_weight].shape[1] + if in_chans != pretrained_in_chans: + if in_chans == 1: + print('=> Converting pretrained input conv {} from {} to 1 channel'.format( + input_conv_weight, pretrained_in_chans)) + conv1_weight = state_dict[input_conv_weight] + state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) + else: + print('=> Discarding pretrained input conv {} since input channel count != {}'.format( + input_conv_weight, pretrained_in_chans)) + del state_dict[input_conv_weight] + strict = False + + classifier_weight = classifier + '.weight' + pretrained_num_classes = state_dict[classifier_weight].shape[0] + if num_classes != pretrained_num_classes: + print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) + del state_dict[classifier_weight] + del state_dict[classifier + '.bias'] + strict = False + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..b5966c28f7207e98ee50745b1bc8f3663c650f9d --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py @@ -0,0 +1,364 @@ +""" MobileNet-V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_act_fn, get_act_layer, HardSwish +from .config import layer_config_kwargs +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', + 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', + 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', + 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100'] + +model_urls = { + 'mobilenetv3_rw': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + 'mobilenetv3_large_075': None, + 'mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + 'mobilenetv3_large_minimal_100': None, + 'mobilenetv3_small_075': None, + 'mobilenetv3_small_100': None, + 'mobilenetv3_small_minimal_100': None, + 'tf_mobilenetv3_large_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + 'tf_mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + 'tf_mobilenetv3_large_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + 'tf_mobilenetv3_small_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + 'tf_mobilenetv3_small_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + 'tf_mobilenetv3_small_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', +} + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the + head convolution without a final batch-norm layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3, self).__init__() + self.drop_rate = drop_rate + + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.classifier = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if weight_init == 'goog': + initialize_weight_goog(m) + else: + initialize_weight_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.global_pool, self.conv_head, self.act2, + nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = MobileNetV3(**model_kwargs) + if pretrained and model_urls[variant]: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model (RW variant). + + Paper: https://arxiv.org/abs/1905.02244 + + This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the + eventual Tensorflow reference impl but has a few differences: + 1. This model has no bias on the head convolution + 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet + 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer + from their parent block + 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count + + Overall the changes are fairly minor and result in a very small parameter count difference and no + top-1/5 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, # one of my mistakes + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 large/small/minimal models. + + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, act_layer), + se_kwargs=dict( + act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet-V3 RW + Attn: See note in gen function for this variant. + """ + # NOTE for train set drop_rate=0.2 + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75""" + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large (Minimalistic) 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small (Minimalistic) 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0. Tensorflow compat variant.""" + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4d46ea8baedaf3d787826eb3bb314b4230514647 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py @@ -0,0 +1,27 @@ +from .config import set_layer_config +from .helpers import load_checkpoint + +from .gen_efficientnet import * +from .mobilenetv3 import * + + +def create_model( + model_name='mnasnet_100', + pretrained=None, + num_classes=1000, + in_chans=3, + checkpoint_path='', + **kwargs): + + model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) + + if model_name in globals(): + create_fn = globals()[model_name] + model = create_fn(**model_kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path and not pretrained: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a6221b3de7b1490c5e712e8b5fcc94c3d9d04295 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py @@ -0,0 +1 @@ +__version__ = '1.0.2' diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..45b17b99bbeba34596569e6e50f6e8a2ebc45c54 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py @@ -0,0 +1,84 @@ +dependencies = ['torch', 'math'] + +from geffnet import efficientnet_b0 +from geffnet import efficientnet_b1 +from geffnet import efficientnet_b2 +from geffnet import efficientnet_b3 + +from geffnet import efficientnet_es + +from geffnet import efficientnet_lite0 + +from geffnet import mixnet_s +from geffnet import mixnet_m +from geffnet import mixnet_l +from geffnet import mixnet_xl + +from geffnet import mobilenetv2_100 +from geffnet import mobilenetv2_110d +from geffnet import mobilenetv2_120d +from geffnet import mobilenetv2_140 + +from geffnet import mobilenetv3_large_100 +from geffnet import mobilenetv3_rw +from geffnet import mnasnet_a1 +from geffnet import mnasnet_b1 +from geffnet import fbnetc_100 +from geffnet import spnasnet_100 + +from geffnet import tf_efficientnet_b0 +from geffnet import tf_efficientnet_b1 +from geffnet import tf_efficientnet_b2 +from geffnet import tf_efficientnet_b3 +from geffnet import tf_efficientnet_b4 +from geffnet import tf_efficientnet_b5 +from geffnet import tf_efficientnet_b6 +from geffnet import tf_efficientnet_b7 +from geffnet import tf_efficientnet_b8 + +from geffnet import tf_efficientnet_b0_ap +from geffnet import tf_efficientnet_b1_ap +from geffnet import tf_efficientnet_b2_ap +from geffnet import tf_efficientnet_b3_ap +from geffnet import tf_efficientnet_b4_ap +from geffnet import tf_efficientnet_b5_ap +from geffnet import tf_efficientnet_b6_ap +from geffnet import tf_efficientnet_b7_ap +from geffnet import tf_efficientnet_b8_ap + +from geffnet import tf_efficientnet_b0_ns +from geffnet import tf_efficientnet_b1_ns +from geffnet import tf_efficientnet_b2_ns +from geffnet import tf_efficientnet_b3_ns +from geffnet import tf_efficientnet_b4_ns +from geffnet import tf_efficientnet_b5_ns +from geffnet import tf_efficientnet_b6_ns +from geffnet import tf_efficientnet_b7_ns +from geffnet import tf_efficientnet_l2_ns_475 +from geffnet import tf_efficientnet_l2_ns + +from geffnet import tf_efficientnet_es +from geffnet import tf_efficientnet_em +from geffnet import tf_efficientnet_el + +from geffnet import tf_efficientnet_cc_b0_4e +from geffnet import tf_efficientnet_cc_b0_8e +from geffnet import tf_efficientnet_cc_b1_8e + +from geffnet import tf_efficientnet_lite0 +from geffnet import tf_efficientnet_lite1 +from geffnet import tf_efficientnet_lite2 +from geffnet import tf_efficientnet_lite3 +from geffnet import tf_efficientnet_lite4 + +from geffnet import tf_mixnet_s +from geffnet import tf_mixnet_m +from geffnet import tf_mixnet_l + +from geffnet import tf_mobilenetv3_large_075 +from geffnet import tf_mobilenetv3_large_100 +from geffnet import tf_mobilenetv3_large_minimal_100 +from geffnet import tf_mobilenetv3_small_075 +from geffnet import tf_mobilenetv3_small_100 +from geffnet import tf_mobilenetv3_small_minimal_100 + diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5162ce214830df501bdb81edb66c095122f69d --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py @@ -0,0 +1,120 @@ +""" ONNX export script + +Export PyTorch models as ONNX graphs. + +This export script originally started as an adaptation of code snippets found at +https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html + +The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph +for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible +with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback +flags are currently required. + +Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for +caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime. + +Most new release of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models. +Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks. + +Copyright 2020 Ross Wightman +""" +import argparse +import torch +import numpy as np + +import onnx +import geffnet + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('output', metavar='ONNX_FILE', + help='output model filename') +parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100', + help='model architecture (default: mobilenetv3_large_100)') +parser.add_argument('--opset', type=int, default=10, + help='ONNX opset to use (default: 10)') +parser.add_argument('--keep-init', action='store_true', default=False, + help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.') +parser.add_argument('--aten-fallback', action='store_true', default=False, + help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.') +parser.add_argument('--dynamic-size', action='store_true', default=False, + help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.') +parser.add_argument('-b', '--batch-size', default=1, type=int, + metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to checkpoint (default: none)') + + +def main(): + args = parser.parse_args() + + args.pretrained = True + if args.checkpoint: + args.pretrained = False + + print("==> Creating PyTorch {} model".format(args.model)) + # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers + # for models using SAME padding + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + exportable=True) + + model.eval() + + example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True) + + # Run model once before export trace, sets padding for models with Conv2dSameExport. This means + # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for + # the input img_size specified in this script. + # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to + # issues in the tracing of the dynamic padding or errors attempting to export the model after jit + # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... + model(example_input) + + print("==> Exporting model to ONNX format at '{}'".format(args.output)) + input_names = ["input0"] + output_names = ["output0"] + dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} + if args.dynamic_size: + dynamic_axes['input0'][2] = 'height' + dynamic_axes['input0'][3] = 'width' + if args.aten_fallback: + export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + else: + export_type = torch.onnx.OperatorExportTypes.ONNX + + torch_out = torch.onnx._export( + model, example_input, args.output, export_params=True, verbose=True, input_names=input_names, + output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes, + opset_version=args.opset, operator_export_type=export_type) + + print("==> Loading and checking exported model from '{}'".format(args.output)) + onnx_model = onnx.load(args.output) + onnx.checker.check_model(onnx_model) # assuming throw on error + print("==> Passed") + + if args.keep_init and args.aten_fallback: + import caffe2.python.onnx.backend as onnx_caffe2 + # Caffe2 loading only works properly in newer PyTorch/ONNX combos when + # keep_initializers_as_inputs and aten_fallback are set to True. + print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output)) + caffe2_backend = onnx_caffe2.prepare(onnx_model) + B = {onnx_model.graph.input[0].name: x.data.numpy()} + c2_out = caffe2_backend.run(B)[0] + np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5) + print("==> Passed") + + +if __name__ == '__main__': + main() diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..ee20bbf9f0f9473370489512eb96ca0b570b5388 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py @@ -0,0 +1,84 @@ +""" ONNX optimization script + +Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. + +NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7), +it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). + +Copyright 2020 Ross Wightman +""" +import argparse +import warnings + +import onnx +from onnx import optimizer + + +parser = argparse.ArgumentParser(description="Optimize ONNX model") + +parser.add_argument("model", help="The ONNX model") +parser.add_argument("--output", required=True, help="The optimized model output filename") + + +def traverse_graph(graph, prefix=''): + content = [] + indent = prefix + ' ' + graphs = [] + num_nodes = 0 + for node in graph.node: + pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) + assert isinstance(gs, list) + content.append(pn) + graphs.extend(gs) + num_nodes += 1 + for g in graphs: + g_count, g_str = traverse_graph(g) + content.append('\n' + g_str) + num_nodes += g_count + return num_nodes, '\n'.join(content) + + +def main(): + args = parser.parse_args() + onnx_model = onnx.load(args.model) + num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) + + # Optimizer passes to perform + passes = [ + #'eliminate_deadend', + 'eliminate_identity', + 'eliminate_nop_dropout', + 'eliminate_nop_pad', + 'eliminate_nop_transpose', + 'eliminate_unused_initializer', + 'extract_constant_to_initializer', + 'fuse_add_bias_into_conv', + 'fuse_bn_into_conv', + 'fuse_consecutive_concats', + 'fuse_consecutive_reduce_unsqueeze', + 'fuse_consecutive_squeezes', + 'fuse_consecutive_transposes', + #'fuse_matmul_add_bias_into_gemm', + 'fuse_pad_into_conv', + #'fuse_transpose_into_gemm', + #'lift_lexical_references', + ] + + # Apply the optimization on the original serialized model + # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing + # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 + # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. + warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." + "Try onnxruntime optimization if this doesn't work.") + optimized_model = optimizer.optimize(onnx_model, passes) + + num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) + print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) + print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) + + # Save the ONNX model + onnx.save(optimized_model, args.output) + + +if __name__ == "__main__": + main() diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py new file mode 100644 index 0000000000000000000000000000000000000000..44399aafababcdf6b84147a0613eb0909730db4b --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py @@ -0,0 +1,27 @@ +import argparse + +import onnx +from caffe2.python.onnx.backend import Caffe2Backend + + +parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2") + +parser.add_argument("model", help="The ONNX model") +parser.add_argument("--c2-prefix", required=True, + help="The output file prefix for the caffe2 model init and predict file. ") + + +def main(): + args = parser.parse_args() + onnx_model = onnx.load(args.model) + caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) + caffe2_init_str = caffe2_init.SerializeToString() + with open(args.c2_prefix + '.init.pb', "wb") as f: + f.write(caffe2_init_str) + caffe2_predict_str = caffe2_predict.SerializeToString() + with open(args.c2_prefix + '.predict.pb', "wb") as f: + f.write(caffe2_predict_str) + + +if __name__ == "__main__": + main() diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3e4fb141b6ef660dcc5b447fd9f368a2ea19a0 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py @@ -0,0 +1,112 @@ +""" ONNX-runtime validation script + +This script was created to verify accuracy and performance of exported ONNX +models running with the onnxruntime. It utilizes the PyTorch dataloader/processing +pipeline for a fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +import onnxruntime +from data import create_loader, resolve_data_config, Dataset +from utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--onnx-input', default='', type=str, metavar='PATH', + help='path to onnx model/weights file') +parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', + help='path to output optimized onnx graph') +parser.add_argument('--profile', action='store_true', default=False, + help='Enable profiler output.') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + + # Set graph optimization level + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.profile: + sess_options.enable_profiling = True + if args.onnx_output_opt: + sess_options.optimized_model_filepath = args.onnx_output_opt + + session = onnxruntime.InferenceSession(args.onnx_input, sess_options) + + data_config = resolve_data_config(None, args) + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + input_name = session.get_inputs()[0].name + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + output = session.run([], {input_name: input.data.numpy()}) + output = output[0] + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, + ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ac3ffc13bae15f9b11f7cbe3705760056ecd7f13 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.2.0 +torchvision>=0.4.0 diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..023e4c30f98164595964423e3a83eefaf7ffdad6 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py @@ -0,0 +1,47 @@ +""" Setup +""" +from setuptools import setup, find_packages +from codecs import open +from os import path + +here = path.abspath(path.dirname(__file__)) + +# Get the long description from the README file +with open(path.join(here, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +exec(open('geffnet/version.py').read()) +setup( + name='geffnet', + version=__version__, + description='(Generic) EfficientNets for PyTorch', + long_description=long_description, + long_description_content_type='text/markdown', + url='https://github.com/rwightman/gen-efficientnet-pytorch', + author='Ross Wightman', + author_email='hello@rwightman.com', + classifiers=[ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + + # Note that this is a string of words separated by whitespace, not a list. + keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet', + packages=find_packages(exclude=['data']), + install_requires=['torch >= 1.4', 'torchvision'], + python_requires='>=3.6', +) diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d327e8bd8120c5cd09ae6c15c3991ccbe27f6c1f --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py @@ -0,0 +1,52 @@ +import os + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + diff --git a/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/validate.py b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd44fbb3165ef81ef81251b6299f6aaa80bf2c2 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/validate.py @@ -0,0 +1,166 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import time +import torch +import torch.nn as nn +import torch.nn.parallel +from contextlib import suppress + +import geffnet +from data import Dataset, create_loader, resolve_data_config +from utils import accuracy, AverageMeter + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00', + help='model architecture (default: dpn92)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +parser.add_argument('--num-gpu', type=int, default=1, + help='Number of GPUS to use') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--no-cuda', dest='no_cuda', action='store_true', + help='') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use native Torch AMP mixed precision.') + + +def main(): + args = parser.parse_args() + + if not args.checkpoint and not args.pretrained: + args.pretrained = True + + amp_autocast = suppress # do nothing + if args.amp: + if not has_native_amp: + print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.") + else: + amp_autocast = torch.cuda.amp.autocast + + # create model + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + scriptable=args.torchscript) + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + + print('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) + + data_config = resolve_data_config(model, args) + + criterion = nn.CrossEntropyLoss() + + if not args.no_cuda: + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=not args.no_cuda, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + model.eval() + end = time.time() + with torch.no_grad(): + for i, (input, target) in enumerate(loader): + if not args.no_cuda: + target = target.cuda() + input = input.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + # compute output + with amp_autocast(): + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +if __name__ == '__main__': + main() diff --git a/controlnet_aux/normalbae/nets/submodules/encoder.py b/controlnet_aux/normalbae/nets/submodules/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7149ca3c0cf2b6e019105af7e645cfbb3eda11 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/encoder.py @@ -0,0 +1,34 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b5_ap' + print('Loading base model ()...'.format(basemodel_name), end='') + repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo') + basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local') + print('Done.') + + # Remove last layer + print('Removing last two layers (global_pool & classifier).') + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + diff --git a/controlnet_aux/normalbae/nets/submodules/submodules.py b/controlnet_aux/normalbae/nets/submodules/submodules.py new file mode 100644 index 0000000000000000000000000000000000000000..409733351bd6ab5d191c800aff1bc05bfa4cb6f8 --- /dev/null +++ b/controlnet_aux/normalbae/nets/submodules/submodules.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +######################################################################################################################## + + +# Upsample + BatchNorm +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleBN, self).__init__() + + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +# Upsample + GroupNorm + Weight Standardization +class UpSampleGN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleGN, self).__init__() + + self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU(), + Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +# Conv2d with weight standardization +class Conv2d(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# normalize +def norm_normalize(norm_out): + min_kappa = 0.01 + norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) + norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 + kappa = F.elu(kappa) + 1.0 + min_kappa + final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) + return final_out + + +# uncertainty-guided sampling (only used during training) +@torch.no_grad() +def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): + device = init_normal.device + B, _, H, W = init_normal.shape + N = int(sampling_ratio * H * W) + beta = beta + + # uncertainty map + uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W + + # gt_invalid_mask (B, H, W) + if gt_norm_mask is not None: + gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') + gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 + uncertainty_map[gt_invalid_mask] = -1e4 + + # (B, H*W) + _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) + + # importance sampling + if int(beta * N) > 0: + importance = idx[:, :int(beta * N)] # B, beta*N + + # remaining + remaining = idx[:, int(beta * N):] # B, H*W - beta*N + + # coverage + num_coverage = N - int(beta * N) + + if num_coverage <= 0: + samples = importance + else: + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = torch.cat((importance, coverage), dim=1) # B, N + + else: + # remaining + remaining = idx[:, :] # B, H*W + + # coverage + num_coverage = N + + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = coverage + + # point coordinates + rows_int = samples // W # 0 for first row, H-1 for last row + rows_float = rows_int / float(H-1) # 0 to 1.0 + rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 + + cols_int = samples % W # 0 for first column, W-1 for last column + cols_float = cols_int / float(W-1) # 0 to 1.0 + cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 + + point_coords = torch.zeros(B, 1, N, 2) + point_coords[:, 0, :, 0] = cols_float # x coord + point_coords[:, 0, :, 1] = rows_float # y coord + point_coords = point_coords.to(device) + return point_coords, rows_int, cols_int \ No newline at end of file diff --git a/controlnet_aux/open_pose/LICENSE b/controlnet_aux/open_pose/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd --- /dev/null +++ b/controlnet_aux/open_pose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). + +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. + +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. + +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. + +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. + +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/controlnet_aux/open_pose/__init__.py b/controlnet_aux/open_pose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e463316aa60aae6117e7131180459a12b7d1dcb8 --- /dev/null +++ b/controlnet_aux/open_pose/__init__.py @@ -0,0 +1,234 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) +# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) +# This preprocessor is licensed by CMU for non-commercial use only. + + +import os + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import json +import warnings +from typing import Callable, List, NamedTuple, Tuple, Union + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from . import util +from .body import Body, BodyResult, Keypoint +from .face import Face +from .hand import Hand + +HandResult = List[Keypoint] +FaceResult = List[Keypoint] + +class PoseResult(NamedTuple): + body: BodyResult + left_hand: Union[HandResult, None] + right_hand: Union[HandResult, None] + face: Union[FaceResult, None] + +def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True): + """ + Draw the detected poses on an empty canvas. + + Args: + poses (List[PoseResult]): A list of PoseResult objects containing the detected poses. + H (int): The height of the canvas. + W (int): The width of the canvas. + draw_body (bool, optional): Whether to draw body keypoints. Defaults to True. + draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True. + draw_face (bool, optional): Whether to draw face keypoints. Defaults to True. + + Returns: + numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses. + """ + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + for pose in poses: + if draw_body: + canvas = util.draw_bodypose(canvas, pose.body.keypoints) + + if draw_hand: + canvas = util.draw_handpose(canvas, pose.left_hand) + canvas = util.draw_handpose(canvas, pose.right_hand) + + if draw_face: + canvas = util.draw_facepose(canvas, pose.face) + + return canvas + + +class OpenposeDetector: + """ + A class for detecting human poses in images using the Openpose model. + + Attributes: + model_dir (str): Path to the directory where the pose models are stored. + """ + def __init__(self, body_estimation, hand_estimation=None, face_estimation=None): + self.body_estimation = body_estimation + self.hand_estimation = hand_estimation + self.face_estimation = face_estimation + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None, local_files_only=False): + + if pretrained_model_or_path == "lllyasviel/ControlNet": + filename = filename or "annotator/ckpts/body_pose_model.pth" + hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth" + face_filename = face_filename or "facenet.pth" + + face_pretrained_model_or_path = "lllyasviel/Annotators" + else: + filename = filename or "body_pose_model.pth" + hand_filename = hand_filename or "hand_pose_model.pth" + face_filename = face_filename or "facenet.pth" + + face_pretrained_model_or_path = pretrained_model_or_path + + if os.path.isdir(pretrained_model_or_path): + body_model_path = os.path.join(pretrained_model_or_path, filename) + hand_model_path = os.path.join(pretrained_model_or_path, hand_filename) + face_model_path = os.path.join(face_pretrained_model_or_path, face_filename) + else: + body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only) + face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only) + + body_estimation = Body(body_model_path) + hand_estimation = Hand(hand_model_path) + face_estimation = Face(face_model_path) + + return cls(body_estimation, hand_estimation, face_estimation) + + def to(self, device): + self.body_estimation.to(device) + self.hand_estimation.to(device) + self.face_estimation.to(device) + return self + + def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]: + left_hand = None + right_hand = None + H, W, _ = oriImg.shape + for x, y, w, is_left in util.handDetect(body, oriImg): + peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + + hand_result = [ + Keypoint(x=peak[0], y=peak[1]) + for peak in peaks + ] + + if is_left: + left_hand = hand_result + else: + right_hand = hand_result + + return left_hand, right_hand + + def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]: + face = util.faceDetect(body, oriImg) + if face is None: + return None + + x, y, w = face + H, W, _ = oriImg.shape + heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :]) + peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + return [ + Keypoint(x=peak[0], y=peak[1]) + for peak in peaks + ] + + return None + + def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]: + """ + Detect poses in the given image. + Args: + oriImg (numpy.ndarray): The input image for pose detection. + include_hand (bool, optional): Whether to include hand detection. Defaults to False. + include_face (bool, optional): Whether to include face detection. Defaults to False. + + Returns: + List[PoseResult]: A list of PoseResult objects containing the detected poses. + """ + oriImg = oriImg[:, :, ::-1].copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.body_estimation(oriImg) + bodies = self.body_estimation.format_body_result(candidate, subset) + + results = [] + for body in bodies: + left_hand, right_hand, face = (None,) * 3 + if include_hand: + left_hand, right_hand = self.detect_hands(body, oriImg) + if include_face: + face = self.detect_face(body, oriImg) + + results.append(PoseResult(BodyResult( + keypoints=[ + Keypoint( + x=keypoint.x / float(W), + y=keypoint.y / float(H) + ) if keypoint is not None else None + for keypoint in body.keypoints + ], + total_score=body.total_score, + total_parts=body.total_parts + ), left_hand, right_hand, face)) + + return results + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs): + if hand_and_face is not None: + warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning) + include_hand = hand_and_face + include_face = hand_and_face + + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + H, W, C = input_image.shape + + poses = self.detect_poses(input_image, include_hand, include_face) + canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face) + + detected_map = canvas + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/open_pose/body.py b/controlnet_aux/open_pose/body.py new file mode 100644 index 0000000000000000000000000000000000000000..fa4c74e4e1e220ee87bac3634bf78c45e87aca55 --- /dev/null +++ b/controlnet_aux/open_pose/body.py @@ -0,0 +1,260 @@ +import math +from typing import List, NamedTuple, Union + +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter + +from . import util +from .model import bodypose_model + + +class Keypoint(NamedTuple): + x: float + y: float + score: float = 1.0 + id: int = -1 + + +class BodyResult(NamedTuple): + # Note: Using `Union` instead of `|` operator as the ladder is a Python + # 3.10 feature. + # Annotator code should be Python 3.8 Compatible, as controlnet repo uses + # Python 3.8 environment. + # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6 + keypoints: List[Union[Keypoint, None]] + total_score: float + total_parts: int + + +class Body(object): + def __init__(self, model_path): + self.model = bodypose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, oriImg): + device = next(iter(self.model.parameters())).device + # scale_search = [0.5, 1.0, 1.5, 2.0] + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(device) + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1])) + + # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs + paf = util.smart_resize_k(paf, fx=stride, fy=stride) + paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1])) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += + paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(18): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) + peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ + [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ + [55, 56], [37, 38], [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + indexA, indexB = limbSeq[k] + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ + np.linspace(candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ + for I in range(len(startend))]) + vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ + for I in range(len(startend))]) + + score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( + 0.5 * oriImg.shape[0] / norm - 1, 0) + criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append( + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + + connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] and j not in connection[:, 4]): + connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array([item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + return candidate, subset + + @staticmethod + def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]: + """ + Format the body results from the candidate and subset arrays into a list of BodyResult objects. + + Args: + candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id + for each body part. + subset (np.ndarray): An array of subsets containing indices to the candidate array for each + person detected. The last two columns of each row hold the total score and total parts + of the person. + + Returns: + List[BodyResult]: A list of BodyResult objects, where each object represents a person with + detected keypoints, total score, and total parts. + """ + return [ + BodyResult( + keypoints=[ + Keypoint( + x=candidate[candidate_index][0], + y=candidate[candidate_index][1], + score=candidate[candidate_index][2], + id=candidate[candidate_index][3] + ) if candidate_index != -1 else None + for candidate_index in person[:18].astype(int) + ], + total_score=person[18], + total_parts=person[19] + ) + for person in subset + ] diff --git a/controlnet_aux/open_pose/face.py b/controlnet_aux/open_pose/face.py new file mode 100644 index 0000000000000000000000000000000000000000..41c7799af10b1f834369464862d41d8f967128c6 --- /dev/null +++ b/controlnet_aux/open_pose/face.py @@ -0,0 +1,364 @@ +import logging + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init +from torchvision.transforms import ToPILImage, ToTensor + +from . import util + + +class FaceNet(Module): + """Model the cascading heatmaps. """ + def __init__(self): + super(FaceNet, self).__init__() + # cnn to make feature map + self.relu = ReLU() + self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2) + self.conv1_1 = Conv2d(in_channels=3, out_channels=64, + kernel_size=3, stride=1, padding=1) + self.conv1_2 = Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1) + self.conv2_1 = Conv2d( + in_channels=64, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv2_2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv3_1 = Conv2d( + in_channels=128, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_2 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_3 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_4 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv4_1 = Conv2d( + in_channels=256, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_3 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_4 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_1 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_3_CPM = Conv2d( + in_channels=512, out_channels=128, kernel_size=3, stride=1, + padding=1) + + # stage1 + self.conv6_1_CPM = Conv2d( + in_channels=128, out_channels=512, kernel_size=1, stride=1, + padding=0) + self.conv6_2_CPM = Conv2d( + in_channels=512, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage2 + self.Mconv1_stage2 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage2 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage3 + self.Mconv1_stage3 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage3 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage4 + self.Mconv1_stage4 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage4 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage5 + self.Mconv1_stage5 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage5 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage6 + self.Mconv1_stage6 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage6 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + for m in self.modules(): + if isinstance(m, Conv2d): + init.constant_(m.bias, 0) + + def forward(self, x): + """Return a list of heatmaps.""" + heatmaps = [] + + h = self.relu(self.conv1_1(x)) + h = self.relu(self.conv1_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv2_1(h)) + h = self.relu(self.conv2_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv3_1(h)) + h = self.relu(self.conv3_2(h)) + h = self.relu(self.conv3_3(h)) + h = self.relu(self.conv3_4(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv4_1(h)) + h = self.relu(self.conv4_2(h)) + h = self.relu(self.conv4_3(h)) + h = self.relu(self.conv4_4(h)) + h = self.relu(self.conv5_1(h)) + h = self.relu(self.conv5_2(h)) + h = self.relu(self.conv5_3_CPM(h)) + feature_map = h + + # stage1 + h = self.relu(self.conv6_1_CPM(h)) + h = self.conv6_2_CPM(h) + heatmaps.append(h) + + # stage2 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage2(h)) + h = self.relu(self.Mconv2_stage2(h)) + h = self.relu(self.Mconv3_stage2(h)) + h = self.relu(self.Mconv4_stage2(h)) + h = self.relu(self.Mconv5_stage2(h)) + h = self.relu(self.Mconv6_stage2(h)) + h = self.Mconv7_stage2(h) + heatmaps.append(h) + + # stage3 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage3(h)) + h = self.relu(self.Mconv2_stage3(h)) + h = self.relu(self.Mconv3_stage3(h)) + h = self.relu(self.Mconv4_stage3(h)) + h = self.relu(self.Mconv5_stage3(h)) + h = self.relu(self.Mconv6_stage3(h)) + h = self.Mconv7_stage3(h) + heatmaps.append(h) + + # stage4 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage4(h)) + h = self.relu(self.Mconv2_stage4(h)) + h = self.relu(self.Mconv3_stage4(h)) + h = self.relu(self.Mconv4_stage4(h)) + h = self.relu(self.Mconv5_stage4(h)) + h = self.relu(self.Mconv6_stage4(h)) + h = self.Mconv7_stage4(h) + heatmaps.append(h) + + # stage5 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage5(h)) + h = self.relu(self.Mconv2_stage5(h)) + h = self.relu(self.Mconv3_stage5(h)) + h = self.relu(self.Mconv4_stage5(h)) + h = self.relu(self.Mconv5_stage5(h)) + h = self.relu(self.Mconv6_stage5(h)) + h = self.Mconv7_stage5(h) + heatmaps.append(h) + + # stage6 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage6(h)) + h = self.relu(self.Mconv2_stage6(h)) + h = self.relu(self.Mconv3_stage6(h)) + h = self.relu(self.Mconv4_stage6(h)) + h = self.relu(self.Mconv5_stage6(h)) + h = self.relu(self.Mconv6_stage6(h)) + h = self.Mconv7_stage6(h) + heatmaps.append(h) + + return heatmaps + + +LOG = logging.getLogger(__name__) +TOTEN = ToTensor() +TOPIL = ToPILImage() + + +params = { + 'gaussian_sigma': 2.5, + 'inference_img_size': 736, # 368, 736, 1312 + 'heatmap_peak_thresh': 0.1, + 'crop_scale': 1.5, + 'line_indices': [ + [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], + [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], + [13, 14], [14, 15], [15, 16], + [17, 18], [18, 19], [19, 20], [20, 21], + [22, 23], [23, 24], [24, 25], [25, 26], + [27, 28], [28, 29], [29, 30], + [31, 32], [32, 33], [33, 34], [34, 35], + [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36], + [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42], + [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], + [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], + [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], + [66, 67], [67, 60] + ], +} + + +class Face(object): + """ + The OpenPose face landmark detector model. + + Args: + inference_size: set the size of the inference image size, suggested: + 368, 736, 1312, default 736 + gaussian_sigma: blur the heatmaps, default 2.5 + heatmap_peak_thresh: return landmark if over threshold, default 0.1 + + """ + def __init__(self, face_model_path, + inference_size=None, + gaussian_sigma=None, + heatmap_peak_thresh=None): + self.inference_size = inference_size or params["inference_img_size"] + self.sigma = gaussian_sigma or params['gaussian_sigma'] + self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"] + self.model = FaceNet() + self.model.load_state_dict(torch.load(face_model_path)) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, face_img): + device = next(iter(self.model.parameters())).device + H, W, C = face_img.shape + + w_size = 384 + x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5 + + x_data = x_data.to(device) + + with torch.no_grad(): + hs = self.model(x_data[None, ...]) + heatmaps = F.interpolate( + hs[-1], + (H, W), + mode='bilinear', align_corners=True).cpu().numpy()[0] + return heatmaps + + def compute_peaks_from_heatmaps(self, heatmaps): + all_peaks = [] + for part in range(heatmaps.shape[0]): + map_ori = heatmaps[part].copy() + binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8) + + if np.sum(binary) == 0: + continue + + positions = np.where(binary > 0.5) + intensities = map_ori[positions] + mi = np.argmax(intensities) + y, x = positions[0][mi], positions[1][mi] + all_peaks.append([x, y]) + + return np.array(all_peaks) \ No newline at end of file diff --git a/controlnet_aux/open_pose/hand.py b/controlnet_aux/open_pose/hand.py new file mode 100644 index 0000000000000000000000000000000000000000..1387c4238c8c3856bb9622edb9b4c883e26c1d59 --- /dev/null +++ b/controlnet_aux/open_pose/hand.py @@ -0,0 +1,90 @@ +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter +from skimage.measure import label + +from . import util +from .model import handpose_model + + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, oriImgRaw): + device = next(iter(self.model.parameters())).device + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize for x in scale_search] + + wsize = 128 + heatmap_avg = np.zeros((wsize, wsize, 22)) + + Hr, Wr, Cr = oriImgRaw.shape + + oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize(oriImg, (scale, scale)) + + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(device) + + with torch.no_grad(): + output = self.model(data).cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (wsize, wsize)) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = util.npmax(map_ori) + y = int(float(y) * float(Hr) / float(wsize)) + x = int(float(x) * float(Wr) / float(wsize)) + all_peaks.append([x, y]) + return np.array(all_peaks) + +if __name__ == "__main__": + hand_estimation = Hand('../model/hand_pose_model.pth') + + # test_image = '../images/hand.jpg' + test_image = '../images/hand.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + peaks = hand_estimation(oriImg) + canvas = util.draw_handpose(oriImg, peaks, True) + cv2.imshow('', canvas) + cv2.waitKey(0) \ No newline at end of file diff --git a/controlnet_aux/open_pose/model.py b/controlnet_aux/open_pose/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3d47268986f8018b2c75307a7725d364b175fe --- /dev/null +++ b/controlnet_aux/open_pose/model.py @@ -0,0 +1,217 @@ +import torch +from collections import OrderedDict + +import torch +import torch.nn as nn + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], + padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], + kernel_size=v[2], stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + +class bodypose_model(nn.Module): + def __init__(self): + super(bodypose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + blocks = {} + block0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1]) + ]) + + + # Stage 1 + block1_1 = OrderedDict([ + ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) + ]) + + block1_2 = OrderedDict([ + ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) + ]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + +class handpose_model(nn.Module): + def __init__(self): + super(handpose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + # stage 1 + block1_0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 diff --git a/controlnet_aux/open_pose/util.py b/controlnet_aux/open_pose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f10ca2dfcbf66fb6e8697503d7ffb336b48b865a --- /dev/null +++ b/controlnet_aux/open_pose/util.py @@ -0,0 +1,383 @@ +import math +import numpy as np +import cv2 +from typing import List, Tuple, Union + +from .body import BodyResult, Keypoint + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: + """ + Draw keypoints and limbs representing body pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose. + keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + H, W, C = canvas.shape + stickwidth = 4 + + limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], + [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], + [2, 1], [1, 15], [15, 17], [1, 16], + [16, 18], + ] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for (k1_index, k2_index), color in zip(limbSeq, colors): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + continue + + Y = np.array([keypoint1.x, keypoint2.x]) * float(W) + X = np.array([keypoint1.y, keypoint2.y]) * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) + + for keypoint, color in zip(keypoints, colors): + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + + return canvas + + +def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + import matplotlib + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + + x1 = int(k1.x * W) + y1 = int(k1.y * H) + x2 = int(k2.x * W) + y2 = int(k2.y * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + """ + Draw keypoints representing face pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, C = canvas.shape + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: + """ + Detect hands in the input body pose keypoints and calculate the bounding box for each hand. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left + corner of the bounding box, the width (height) of the bounding box, and + a boolean flag indicating whether the hand is a left hand (True) or a + right hand (False). + + Notes: + - The width and height of the bounding boxes are equal since the network requires squared input. + - The minimum bounding box size is 20 pixels. + """ + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + left_shoulder = keypoints[5] + left_elbow = keypoints[6] + left_wrist = keypoints[7] + right_shoulder = keypoints[2] + right_elbow = keypoints[3] + right_wrist = keypoints[4] + + # if any of three not detected + has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist)) + has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) + if not (has_left or has_right): + return [] + + hands = [] + #left hand + if has_left: + hands.append([ + left_shoulder.x, left_shoulder.y, + left_elbow.x, left_elbow.y, + left_wrist.x, left_wrist.y, + True + ]) + # right hand + if has_right: + hands.append([ + right_shoulder.x, right_shoulder.y, + right_elbow.x, right_elbow.y, + right_wrist.x, right_wrist.y, + False + ]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append((int(x), int(y), int(width), is_left)) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]: + """ + Detect the face in the input body pose keypoints and calculate the bounding box for the face. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the + bounding box and the width (height) of the bounding box, or None if the + face is not detected or the bounding box width is less than 20 pixels. + + Notes: + - The width and height of the bounding box are equal. + - The minimum bounding box size is 20 pixels. + """ + # left right eye ear 14 15 16 17 + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + head = keypoints[0] + left_eye = keypoints[14] + right_eye = keypoints[15] + left_ear = keypoints[16] + right_ear = keypoints[17] + + if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)): + return None + + width = 0.0 + x0, y0 = head.x, head.y + + if left_eye is not None: + x1, y1 = left_eye.x, left_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if right_eye is not None: + x1, y1 = right_eye.x, right_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if left_ear is not None: + x1, y1 = left_ear.x, left_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if right_ear is not None: + x1, y1 = right_ear.x, right_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + return int(x), int(y), int(width) + else: + return None + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/controlnet_aux/pidi/LICENSE b/controlnet_aux/pidi/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..913b6cf92c19d37b6ee4f7bc99c65f655e7f840c --- /dev/null +++ b/controlnet_aux/pidi/LICENSE @@ -0,0 +1,21 @@ +It is just for research purpose, and commercial use should be contacted with authors first. + +Copyright (c) 2021 Zhuo Su + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/controlnet_aux/pidi/__init__.py b/controlnet_aux/pidi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8269973e0fb66ec280458d9f8757e74a63846de --- /dev/null +++ b/controlnet_aux/pidi/__init__.py @@ -0,0 +1,84 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, nms, resize_image, safe_step +from .model import pidinet + + +class PidiNetDetector: + def __init__(self, netNetwork): + self.netNetwork = netNetwork + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): + filename = filename or "table5_pidinet.pth" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + netNetwork = pidinet() + netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) + netNetwork.eval() + + return cls(netNetwork) + + def to(self, device): + self.netNetwork.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=False, **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + device = next(iter(self.netNetwork.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + assert input_image.ndim == 3 + input_image = input_image[:, :, ::-1].copy() + with torch.no_grad(): + image_pidi = torch.from_numpy(input_image).float().to(device) + image_pidi = image_pidi / 255.0 + image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') + edge = self.netNetwork(image_pidi)[-1] + edge = edge.cpu().numpy() + if apply_filter: + edge = edge > 0.5 + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge[0, 0] + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/pidi/model.py b/controlnet_aux/pidi/model.py new file mode 100644 index 0000000000000000000000000000000000000000..16595b35a4f75a6d2b0e832e24b6e11706d77326 --- /dev/null +++ b/controlnet_aux/pidi/model.py @@ -0,0 +1,681 @@ +""" +Author: Zhuo Su, Wenzhe Liu +Date: Feb 18, 2021 +""" + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + +nets = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'cv', + }, + 'rrrv4': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'cv', + }, + 'c16': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cd', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cd', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cd', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cd', + }, + 'a16': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'ad', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'ad', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'ad', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'ad', + }, + 'r16': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'rd', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'rd', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'rd', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'rd', + }, + 'carv4': { + 'layer0': 'cd', + 'layer1': 'ad', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'ad', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'ad', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'ad', + 'layer14': 'rd', + 'layer15': 'cv', + }, + } + +def createConvFunc(op_type): + assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) + if op_type == 'cv': + return F.conv2d + + if op_type == 'cd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' + assert padding == dilation, 'padding for cd_conv set wrong' + + weights_c = weights.sum(dim=[2, 3], keepdim=True) + yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) + y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y - yc + return func + elif op_type == 'ad': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' + assert padding == dilation, 'padding for ad_conv set wrong' + + shape = weights.shape + weights = weights.view(shape[0], shape[1], -1) + weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise + y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + return func + elif op_type == 'rd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' + padding = 2 * dilation + + shape = weights.shape + if weights.is_cuda: + buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) + else: + buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device) + weights = weights.view(shape[0], shape[1], -1) + buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] + buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] + buffer[:, :, 12] = 0 + buffer = buffer.view(shape[0], shape[1], 5, 5) + y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + return func + else: + print('impossible to be here unless you force that') + return None + +class Conv2d(nn.Module): + def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False): + super(Conv2d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.pdc = pdc + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + + return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class CSAM(nn.Module): + """ + Compact Spatial Attention Module + """ + def __init__(self, channels): + super(CSAM, self).__init__() + + mid_channels = 4 + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + y = self.relu1(x) + y = self.conv1(y) + y = self.conv2(y) + y = self.sigmoid(y) + + return x * y + +class CDCM(nn.Module): + """ + Compact Dilation Convolution based Module + """ + def __init__(self, in_channels, out_channels): + super(CDCM, self).__init__() + + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) + self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) + self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) + self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + x = self.relu1(x) + x = self.conv1(x) + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x3 = self.conv2_3(x) + x4 = self.conv2_4(x) + return x1 + x2 + x3 + x4 + + +class MapReduce(nn.Module): + """ + Reduce feature maps into a single edge map + """ + def __init__(self, channels): + super(MapReduce, self).__init__() + self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + return self.conv(x) + + +class PDCBlock(nn.Module): + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock, self).__init__() + self.stride=stride + + self.stride=stride + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + +class PDCBlock_converted(nn.Module): + """ + CPDC, APDC can be converted to vanilla 3x3 convolution + RPDC can be converted to vanilla 5x5 convolution + """ + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock_converted, self).__init__() + self.stride=stride + + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + if pdc == 'rd': + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) + else: + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + +class PiDiNet(nn.Module): + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d(3, self.inplane, + kernel_size=init_kernel_size, padding=init_padding, bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias + nn.init.constant_(self.classifier.weight, 0.25) + nn.init.constant_(self.classifier.bias, 0) + + # print('initialization done') + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) + + outputs = [e1, e2, e3, e4] + + output = self.classifier(torch.cat(outputs, dim=1)) + #if not self.training: + # return torch.sigmoid(output) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs + +def config_model(model): + model_options = list(nets.keys()) + assert model in model_options, \ + 'unrecognized model, please choose from %s' % str(model_options) + + # print(str(nets[model])) + + pdcs = [] + for i in range(16): + layer_name = 'layer%d' % i + op = nets[model][layer_name] + pdcs.append(createConvFunc(op)) + + return pdcs + +def pidinet(): + pdcs = config_model('carv4') + dil = 24 #if args.dil else None + return PiDiNet(60, pdcs, dil=dil, sa=True) + + +if __name__ == '__main__': + model = pidinet() + ckp = torch.load('table5_pidinet.pth')['state_dict'] + model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) + im = cv2.imread('examples/test_my/cat_v4.png') + im = img2tensor(im).unsqueeze(0)/255. + res = model(im)[-1] + res = res>0.5 + res = res.float() + res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8) + print(res.shape) + cv2.imwrite('edge.png', res) diff --git a/controlnet_aux/processor.py b/controlnet_aux/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..12cb6b085ea080d39c225ebf7d7f13061b42d125 --- /dev/null +++ b/controlnet_aux/processor.py @@ -0,0 +1,148 @@ +""" +This file contains a Processor that can be used to process images with controlnet aux processors +""" +import io +import logging +from typing import Dict, Optional, Union + +from PIL import Image + +from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, + LeresDetector, LineartAnimeDetector, + LineartDetector, MediapipeFaceDetector, + MidasDetector, MLSDdetector, NormalBaeDetector, + OpenposeDetector, PidiNetDetector, ZoeDetector, + DWposeDetector) + +LOGGER = logging.getLogger(__name__) + + +MODELS = { + # checkpoint models + 'scribble_hed': {'class': HEDdetector, 'checkpoint': True}, + 'softedge_hed': {'class': HEDdetector, 'checkpoint': True}, + 'scribble_hedsafe': {'class': HEDdetector, 'checkpoint': True}, + 'softedge_hedsafe': {'class': HEDdetector, 'checkpoint': True}, + 'depth_midas': {'class': MidasDetector, 'checkpoint': True}, + 'mlsd': {'class': MLSDdetector, 'checkpoint': True}, + 'openpose': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_face': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_faceonly': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_full': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_hand': {'class': OpenposeDetector, 'checkpoint': True}, + 'dwpose': {'class': DWposeDetector, 'checkpoint': True}, + 'scribble_pidinet': {'class': PidiNetDetector, 'checkpoint': True}, + 'softedge_pidinet': {'class': PidiNetDetector, 'checkpoint': True}, + 'scribble_pidsafe': {'class': PidiNetDetector, 'checkpoint': True}, + 'softedge_pidsafe': {'class': PidiNetDetector, 'checkpoint': True}, + 'normal_bae': {'class': NormalBaeDetector, 'checkpoint': True}, + 'lineart_coarse': {'class': LineartDetector, 'checkpoint': True}, + 'lineart_realistic': {'class': LineartDetector, 'checkpoint': True}, + 'lineart_anime': {'class': LineartAnimeDetector, 'checkpoint': True}, + 'depth_zoe': {'class': ZoeDetector, 'checkpoint': True}, + 'depth_leres': {'class': LeresDetector, 'checkpoint': True}, + 'depth_leres++': {'class': LeresDetector, 'checkpoint': True}, + # instantiate + 'shuffle': {'class': ContentShuffleDetector, 'checkpoint': False}, + 'mediapipe_face': {'class': MediapipeFaceDetector, 'checkpoint': False}, + 'canny': {'class': CannyDetector, 'checkpoint': False}, +} + + +MODEL_PARAMS = { + 'scribble_hed': {'scribble': True}, + 'softedge_hed': {'scribble': False}, + 'scribble_hedsafe': {'scribble': True, 'safe': True}, + 'softedge_hedsafe': {'scribble': False, 'safe': True}, + 'depth_midas': {}, + 'mlsd': {}, + 'openpose': {'include_body': True, 'include_hand': False, 'include_face': False}, + 'openpose_face': {'include_body': True, 'include_hand': False, 'include_face': True}, + 'openpose_faceonly': {'include_body': False, 'include_hand': False, 'include_face': True}, + 'openpose_full': {'include_body': True, 'include_hand': True, 'include_face': True}, + 'openpose_hand': {'include_body': False, 'include_hand': True, 'include_face': False}, + 'dwpose': {}, + 'scribble_pidinet': {'safe': False, 'scribble': True}, + 'softedge_pidinet': {'safe': False, 'scribble': False}, + 'scribble_pidsafe': {'safe': True, 'scribble': True}, + 'softedge_pidsafe': {'safe': True, 'scribble': False}, + 'normal_bae': {}, + 'lineart_realistic': {'coarse': False}, + 'lineart_coarse': {'coarse': True}, + 'lineart_anime': {}, + 'canny': {}, + 'shuffle': {}, + 'depth_zoe': {}, + 'depth_leres': {'boost': False}, + 'depth_leres++': {'boost': True}, + 'mediapipe_face': {}, +} + +CHOICES = f"Choices for the processor are {list(MODELS.keys())}" + + +class Processor: + def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: + """Processor that can be used to process images with controlnet aux processors + + Args: + processor_id (str): processor name, options are 'hed, midas, mlsd, openpose, + pidinet, normalbae, lineart, lineart_coarse, lineart_anime, + canny, content_shuffle, zoe, mediapipe_face + params (Optional[Dict]): parameters for the processor + """ + LOGGER.info(f"Loading {processor_id}") + + if processor_id not in MODELS: + raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}") + + self.processor_id = processor_id + self.processor = self.load_processor(self.processor_id) + + # load default params + self.params = MODEL_PARAMS[self.processor_id] + # update with user params + if params: + self.params.update(params) + + def load_processor(self, processor_id: str) -> 'Processor': + """Load controlnet aux processors + + Args: + processor_id (str): processor name + + Returns: + Processor: controlnet aux processor + """ + processor = MODELS[processor_id]['class'] + + # check if the proecssor is a checkpoint model + if MODELS[processor_id]['checkpoint']: + processor = processor.from_pretrained("lllyasviel/Annotators") + else: + processor = processor() + return processor + + def __call__(self, image: Union[Image.Image, bytes], + to_pil: bool = True) -> Union[Image.Image, bytes]: + """processes an image with a controlnet aux processor + + Args: + image (Union[Image.Image, bytes]): input image in bytes or PIL Image + to_pil (bool): whether to return bytes or PIL Image + + Returns: + Union[Image.Image, bytes]: processed image in bytes or PIL Image + """ + # check if bytes or PIL Image + if isinstance(image, bytes): + image = Image.open(io.BytesIO(image)).convert("RGB") + + processed_image = self.processor(image, **self.params) + + if to_pil: + return processed_image + else: + output_bytes = io.BytesIO() + processed_image.save(output_bytes, format='JPEG') + return output_bytes.getvalue() diff --git a/controlnet_aux/segment_anything/__init__.py b/controlnet_aux/segment_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..048c096c3a43e150b61cc970f34cedf235e453af --- /dev/null +++ b/controlnet_aux/segment_anything/__init__.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from typing import Union + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from .automatic_mask_generator import SamAutomaticMaskGenerator +from .build_sam import sam_model_registry + + +class SamDetector: + def __init__(self, mask_generator: SamAutomaticMaskGenerator): + self.mask_generator = mask_generator + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, model_type="vit_h", filename="sam_vit_h_4b8939.pth", subfolder=None, cache_dir=None): + """ + Possible model_type : vit_h, vit_l, vit_b, vit_t + download weights from https://github.com/facebookresearch/segment-anything + """ + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder, cache_dir=cache_dir) + + sam = sam_model_registry[model_type](checkpoint=model_path) + + if torch.cuda.is_available(): + sam.to("cuda") + + mask_generator = SamAutomaticMaskGenerator(sam) + + return cls(mask_generator) + + + def show_anns(self, anns): + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + h, w = anns[0]['segmentation'].shape + final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") + for ann in sorted_anns: + m = ann['segmentation'] + img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) + for i in range(3): + img[:,:,i] = np.random.randint(255, dtype=np.uint8) + final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255))) + + return np.array(final_img, dtype=np.uint8) + + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> Image.Image: + if "image" in kwargs: + warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("image") + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + # Generate Masks + masks = self.mask_generator.generate(input_image) + # Create map + map = self.show_anns(masks) + + detected_map = map + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/segment_anything/automatic_mask_generator.py b/controlnet_aux/segment_anything/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a8c969207f119feff7087f94e044403acdff00 --- /dev/null +++ b/controlnet_aux/segment_anything/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/controlnet_aux/segment_anything/build_sam.py b/controlnet_aux/segment_anything/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..9a52c506b69d29ee2356cc0e62274fe6f6ee075b --- /dev/null +++ b/controlnet_aux/segment_anything/build_sam.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "vit_t": build_sam_vit_t, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam + + diff --git a/controlnet_aux/segment_anything/modeling/__init__.py b/controlnet_aux/segment_anything/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa261b8356b8c1174139c19782657abca0cfec2 --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer +from .tiny_vit_sam import TinyViT diff --git a/controlnet_aux/segment_anything/modeling/common.py b/controlnet_aux/segment_anything/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/controlnet_aux/segment_anything/modeling/image_encoder.py b/controlnet_aux/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..66351d9d7c589be693f4b3485901d3bdfed54d4a --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/controlnet_aux/segment_anything/modeling/mask_decoder.py b/controlnet_aux/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2fdb03d535a91fa725d1ec4e92a7a1f217dfe0 --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/controlnet_aux/segment_anything/modeling/prompt_encoder.py b/controlnet_aux/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3 --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/controlnet_aux/segment_anything/modeling/sam.py b/controlnet_aux/segment_anything/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..45b9e7c56d10cc47e7ed0739e35d850bfccbb257 --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/sam.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple, Union + +from .tiny_vit_sam import TinyViT +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: Union[ImageEncoderViT, TinyViT], + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py b/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..d06e6b56c65206943467b3bc7422a6b96f2ec205 --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py @@ -0,0 +1,716 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath,\ + to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c=2 + if(out_dim==320 or out_dim==448 or out_dim==576): + stride_c=1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size=img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + #print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B,_,C=x.size() + x = x.view(B, 64, 64, C) + x=x.permute(0, 3, 1, 2) + x=self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + #x = self.norm_head(x) + #x = self.head(x) + return x + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' + url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) diff --git a/controlnet_aux/segment_anything/modeling/transformer.py b/controlnet_aux/segment_anything/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/controlnet_aux/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/controlnet_aux/segment_anything/predictor.py b/controlnet_aux/segment_anything/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a3820fb7de8647e5d6adf229debc498b33caad62 --- /dev/null +++ b/controlnet_aux/segment_anything/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/controlnet_aux/segment_anything/utils/__init__.py b/controlnet_aux/segment_anything/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/controlnet_aux/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/controlnet_aux/segment_anything/utils/amg.py b/controlnet_aux/segment_anything/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..be064071ef399fea96c673ad173689656c23534a --- /dev/null +++ b/controlnet_aux/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/controlnet_aux/segment_anything/utils/onnx.py b/controlnet_aux/segment_anything/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe --- /dev/null +++ b/controlnet_aux/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/controlnet_aux/segment_anything/utils/transforms.py b/controlnet_aux/segment_anything/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85 --- /dev/null +++ b/controlnet_aux/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/controlnet_aux/shuffle/__init__.py b/controlnet_aux/shuffle/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e50f7cd0058f1765eb0133f8e0879e007608c01f --- /dev/null +++ b/controlnet_aux/shuffle/__init__.py @@ -0,0 +1,100 @@ +import warnings + +import cv2 +import numpy as np +from PIL import Image + +from ..util import HWC3, img2mask, make_noise_disk, resize_image + + +class ContentShuffleDetector: + def __call__(self, input_image, h=None, w=None, f=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + H, W, C = input_image.shape + if h is None: + h = H + if w is None: + w = W + if f is None: + f = 256 + x = make_noise_disk(h, w, 1, f) * float(W - 1) + y = make_noise_disk(h, w, 1, f) * float(H - 1) + flow = np.concatenate([x, y], axis=2).astype(np.float32) + detected_map = cv2.remap(input_image, flow, None, cv2.INTER_LINEAR) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + + +class ColorShuffleDetector: + def __call__(self, img): + H, W, C = img.shape + F = np.random.randint(64, 384) + A = make_noise_disk(H, W, 3, F) + B = make_noise_disk(H, W, 3, F) + C = (A + B) / 2.0 + A = (C + (A - C) * 3.0).clip(0, 1) + B = (C + (B - C) * 3.0).clip(0, 1) + L = img.astype(np.float32) / 255.0 + Y = A * L + B * (1 - L) + Y -= np.min(Y, axis=(0, 1), keepdims=True) + Y /= np.maximum(np.max(Y, axis=(0, 1), keepdims=True), 1e-5) + Y *= 255.0 + return Y.clip(0, 255).astype(np.uint8) + + +class GrayDetector: + def __call__(self, img): + eps = 1e-5 + X = img.astype(np.float32) + r, g, b = X[:, :, 0], X[:, :, 1], X[:, :, 2] + kr, kg, kb = [random.random() + eps for _ in range(3)] + ks = kr + kg + kb + kr /= ks + kg /= ks + kb /= ks + Y = r * kr + g * kg + b * kb + Y = np.stack([Y] * 3, axis=2) + return Y.clip(0, 255).astype(np.uint8) + + +class DownSampleDetector: + def __call__(self, img, level=3, k=16.0): + h = img.astype(np.float32) + for _ in range(level): + h += np.random.normal(loc=0.0, scale=k, size=h.shape) + h = cv2.pyrDown(h) + for _ in range(level): + h = cv2.pyrUp(h) + h += np.random.normal(loc=0.0, scale=k, size=h.shape) + return h.clip(0, 255).astype(np.uint8) + + +class Image2MaskShuffleDetector: + def __init__(self, resolution=(640, 512)): + self.H, self.W = resolution + + def __call__(self, img): + m = img2mask(img, self.H, self.W) + m *= 255.0 + return m.clip(0, 255).astype(np.uint8) diff --git a/controlnet_aux/teed/Fsmish.py b/controlnet_aux/teed/Fsmish.py new file mode 100644 index 0000000000000000000000000000000000000000..69691029aee958e66076a9990f258546ee3c7eaf --- /dev/null +++ b/controlnet_aux/teed/Fsmish.py @@ -0,0 +1,19 @@ +""" +Script based on: +Wang, Xueliang, Honge Ren, and Achuan Wang. + "Smish: A Novel Activation Function for Deep Learning Methods. + " Electronics 11.4 (2022): 540. +""" + +# import pytorch +import torch + + +@torch.jit.script +def smish(input): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x)))) + See additional documentation for mish class. + """ + return input * torch.tanh(torch.log(1 + torch.sigmoid(input))) diff --git a/controlnet_aux/teed/LICENSE.txt b/controlnet_aux/teed/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a99ffdd7372b1bfa44ea302330343cb7370d0e9 --- /dev/null +++ b/controlnet_aux/teed/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Xavier Soria Poma + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/controlnet_aux/teed/Xsmish.py b/controlnet_aux/teed/Xsmish.py new file mode 100644 index 0000000000000000000000000000000000000000..67b6d34a153e64483ef4140c90b0269d15f446f7 --- /dev/null +++ b/controlnet_aux/teed/Xsmish.py @@ -0,0 +1,41 @@ +""" +Script based on: +Wang, Xueliang, Honge Ren, and Achuan Wang. + "Smish: A Novel Activation Function for Deep Learning Methods. + " Electronics 11.4 (2022): 540. +smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x))) +""" + +# import pytorch +# import activation functions +from torch import nn + +from .Fsmish import smish + + +class Smish(nn.Module): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + Shape: + - Input: (N, *) where * means, any number of additional + dimensions + - Output: (N, *), same shape as the input + Examples: + >>> m = Mish() + >>> input = torch.randn(2) + >>> output = m(input) + Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html + """ + + def __init__(self): + """ + Init method. + """ + super().__init__() + + def forward(self, input): + """ + Forward pass of the function. + """ + return smish(input) diff --git a/controlnet_aux/teed/__init__.py b/controlnet_aux/teed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..378ce96a0f700de19777ff34a5a9c57f4317037e --- /dev/null +++ b/controlnet_aux/teed/__init__.py @@ -0,0 +1,84 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image, safe_step +from .ted import TED + + +class TEEDdetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None): + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download( + pretrained_model_or_path, filename, subfolder=subfolder + ) + + model = TED() + model.load_state_dict(torch.load(model_path, map_location="cpu")) + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__( + self, + input_image, + detect_resolution=512, + safe_steps=2, + output_type="pil", + ): + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + original_height, original_width, _ = input_image.shape + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + height, width, _ = input_image.shape + with torch.no_grad(): + image_teed = torch.from_numpy(input_image.copy()).float().to(device) + image_teed = rearrange(image_teed, "h w c -> 1 c h w") + edges = self.model(image_teed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [ + cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) + for e in edges + ] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + if safe_steps != 0: + edge = safe_step(edge, safe_steps) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge + detected_map = HWC3(detected_map) + + detected_map = cv2.resize( + detected_map, + (original_width, original_height), + interpolation=cv2.INTER_LINEAR, + ) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/teed/ted.py b/controlnet_aux/teed/ted.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6a87cff77a6ced39f5274eb7197e75856a83eb --- /dev/null +++ b/controlnet_aux/teed/ted.py @@ -0,0 +1,332 @@ +# Original from: https://github.com/xavysp/TEED +# TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3 +# with a Slightly modification +# LDC parameters: +# 155665 +# TED > 58K + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .Fsmish import smish as Fsmish +from .Xsmish import Smish + + +def weight_init(m): + if isinstance(m, (nn.Conv2d,)): + torch.nn.init.xavier_normal_(m.weight, gain=1.0) + + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + # for fusion layer + if isinstance(m, (nn.ConvTranspose2d,)): + torch.nn.init.xavier_normal_(m.weight, gain=1.0) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + +class CoFusion(nn.Module): + # from LDC + + def __init__(self, in_ch, out_ch): + super(CoFusion, self).__init__() + self.conv1 = nn.Conv2d( + in_ch, 32, kernel_size=3, stride=1, padding=1 + ) # before 64 + self.conv3 = nn.Conv2d( + 32, out_ch, kernel_size=3, stride=1, padding=1 + ) # before 64 instead of 32 + self.relu = nn.ReLU() + self.norm_layer1 = nn.GroupNorm(4, 32) # before 64 + + def forward(self, x): + # fusecat = torch.cat(x, dim=1) + attn = self.relu(self.norm_layer1(self.conv1(x))) + attn = F.softmax(self.conv3(attn), dim=1) + return ((x * attn).sum(1)).unsqueeze(1) + + +class CoFusion2(nn.Module): + # TEDv14-3 + def __init__(self, in_ch, out_ch): + super(CoFusion2, self).__init__() + self.conv1 = nn.Conv2d( + in_ch, 32, kernel_size=3, stride=1, padding=1 + ) # before 64 + # self.conv2 = nn.Conv2d(32, 32, kernel_size=3, + # stride=1, padding=1)# before 64 + self.conv3 = nn.Conv2d( + 32, out_ch, kernel_size=3, stride=1, padding=1 + ) # before 64 instead of 32 + self.smish = Smish() # nn.ReLU(inplace=True) + + def forward(self, x): + # fusecat = torch.cat(x, dim=1) + attn = self.conv1(self.smish(x)) + attn = self.conv3(self.smish(attn)) # before , )dim=1) + + # return ((fusecat * attn).sum(1)).unsqueeze(1) + return ((x * attn).sum(1)).unsqueeze(1) + + +class DoubleFusion(nn.Module): + # TED fusion before the final edge map prediction + def __init__(self, in_ch, out_ch): + super(DoubleFusion, self).__init__() + self.DWconv1 = nn.Conv2d( + in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch + ) # before 64 + self.PSconv1 = nn.PixelShuffle(1) + + self.DWconv2 = nn.Conv2d( + 24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24 + ) # before 64 instead of 32 + + self.AF = Smish() # XAF() #nn.Tanh()# XAF() # # Smish()# + + def forward(self, x): + # fusecat = torch.cat(x, dim=1) + attn = self.PSconv1( + self.DWconv1(self.AF(x)) + ) # #TEED best res TEDv14 [8, 32, 352, 352] + + attn2 = self.PSconv1( + self.DWconv2(self.AF(attn)) + ) # #TEED best res TEDv14[8, 3, 352, 352] + + return Fsmish(((attn2 + attn).sum(1)).unsqueeze(1)) # TED best res + + +class _DenseLayer(nn.Sequential): + def __init__(self, input_features, out_features): + super(_DenseLayer, self).__init__() + + ( + self.add_module( + "conv1", + nn.Conv2d( + input_features, + out_features, + kernel_size=3, + stride=1, + padding=2, + bias=True, + ), + ), + ) + (self.add_module("smish1", Smish()),) + self.add_module( + "conv2", + nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True), + ) + + def forward(self, x): + x1, x2 = x + + new_features = super(_DenseLayer, self).forward(Fsmish(x1)) # F.relu() + + return 0.5 * (new_features + x2), x2 + + +class _DenseBlock(nn.Sequential): + def __init__(self, num_layers, input_features, out_features): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer(input_features, out_features) + self.add_module("denselayer%d" % (i + 1), layer) + input_features = out_features + + +class UpConvBlock(nn.Module): + def __init__(self, in_features, up_scale): + super(UpConvBlock, self).__init__() + self.up_factor = 2 + self.constant_features = 16 + + layers = self.make_deconv_layers(in_features, up_scale) + assert layers is not None, layers + self.features = nn.Sequential(*layers) + + def make_deconv_layers(self, in_features, up_scale): + layers = [] + all_pads = [0, 0, 1, 3, 7] + for i in range(up_scale): + kernel_size = 2**up_scale + pad = all_pads[up_scale] # kernel_size-1 + out_features = self.compute_out_features(i, up_scale) + layers.append(nn.Conv2d(in_features, out_features, 1)) + layers.append(Smish()) + layers.append( + nn.ConvTranspose2d( + out_features, out_features, kernel_size, stride=2, padding=pad + ) + ) + in_features = out_features + return layers + + def compute_out_features(self, idx, up_scale): + return 1 if idx == up_scale - 1 else self.constant_features + + def forward(self, x): + return self.features(x) + + +class SingleConvBlock(nn.Module): + def __init__(self, in_features, out_features, stride, use_ac=False): + super(SingleConvBlock, self).__init__() + # self.use_bn = use_bs + self.use_ac = use_ac + self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True) + if self.use_ac: + self.smish = Smish() + + def forward(self, x): + x = self.conv(x) + if self.use_ac: + return self.smish(x) + else: + return x + + +class DoubleConvBlock(nn.Module): + def __init__( + self, in_features, mid_features, out_features=None, stride=1, use_act=True + ): + super(DoubleConvBlock, self).__init__() + + self.use_act = use_act + if out_features is None: + out_features = mid_features + self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1) + self.smish = Smish() # nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.smish(x) + x = self.conv2(x) + if self.use_act: + x = self.smish(x) + return x + + +class TED(nn.Module): + """Definition of Tiny and Efficient Edge Detector + model + """ + + def __init__(self): + super(TED, self).__init__() + self.block_1 = DoubleConvBlock( + 3, + 16, + 16, + stride=2, + ) + self.block_2 = DoubleConvBlock(16, 32, use_act=False) + self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # skip1 connection, see fig. 2 + self.side_1 = SingleConvBlock(16, 32, 2) + + # skip2 connection, see fig. 2 + self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1) + + # USNet + self.up_block_1 = UpConvBlock(16, 1) + self.up_block_2 = UpConvBlock(32, 1) + self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1) + + self.block_cat = DoubleFusion(3, 3) # TEED: DoubleFusion + + self.apply(weight_init) + + def slice(self, tensor, slice_shape): + t_shape = tensor.shape + img_h, img_w = slice_shape + if img_w != t_shape[-1] or img_h != t_shape[2]: + new_tensor = F.interpolate( + tensor, size=(img_h, img_w), mode="bicubic", align_corners=False + ) + + else: + new_tensor = tensor + # tensor[..., :height, :width] + return new_tensor + + def resize_input(self, tensor): + t_shape = tensor.shape + if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0: + img_w = ((t_shape[3] // 8) + 1) * 8 + img_h = ((t_shape[2] // 8) + 1) * 8 + new_tensor = F.interpolate( + tensor, size=(img_h, img_w), mode="bicubic", align_corners=False + ) + else: + new_tensor = tensor + return new_tensor + + def crop_bdcn(data1, h, w, crop_h, crop_w): + # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN + _, _, h1, w1 = data1.size() + assert h <= h1 and w <= w1 + data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w] + return data + + def forward(self, x, single_test=False): + assert x.ndim == 4, x.shape + # supose the image size is 352x352 + + # Block 1 + block_1 = self.block_1(x) # [8,16,176,176] + block_1_side = self.side_1(block_1) # 16 [8,32,88,88] + + # Block 2 + block_2 = self.block_2(block_1) # 32 # [8,32,176,176] + block_2_down = self.maxpool(block_2) # [8,32,88,88] + block_2_add = block_2_down + block_1_side # [8,32,88,88] + + # Block 3 + block_3_pre_dense = self.pre_dense_3( + block_2_down + ) # [8,64,88,88] block 3 L connection + block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88] + + # upsampling blocks + out_1 = self.up_block_1(block_1) + out_2 = self.up_block_2(block_2) + out_3 = self.up_block_3(block_3) + + results = [out_1, out_2, out_3] + + # concatenate multiscale outputs + block_cat = torch.cat(results, dim=1) # Bx6xHxW + block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion + + results.append(block_cat) + return results + + +if __name__ == "__main__": + batch_size = 8 + img_height = 352 + img_width = 352 + + # device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cpu" + input = torch.rand(batch_size, 3, img_height, img_width).to(device) + # target = torch.rand(batch_size, 1, img_height, img_width).to(device) + print(f"input shape: {input.shape}") + model = TED().to(device) + output = model(input) + print(f"output shapes: {[t.shape for t in output]}") + + # for i in range(20000): + # print(i) + # output = model(input) + # loss = nn.MSELoss()(output[-1], target) + # loss.backward() diff --git a/controlnet_aux/tests/requirements.txt b/controlnet_aux/tests/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/tests/test_image.png b/controlnet_aux/tests/test_image.png new file mode 100644 index 0000000000000000000000000000000000000000..c4a751e31da45af83c8a3d5ec02cf8c22c7bb8e9 Binary files /dev/null and b/controlnet_aux/tests/test_image.png differ diff --git a/controlnet_aux/tests/test_processor.py b/controlnet_aux/tests/test_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..cca7e16f91ed52764f15d52cf374ab0050d12fab --- /dev/null +++ b/controlnet_aux/tests/test_processor.py @@ -0,0 +1,89 @@ +"""Test the Processor class.""" +import unittest +from PIL import Image + +from controlnet_aux.processor import Processor + + +class TestProcessor(unittest.TestCase): + def test_hed(self): + processor = Processor('hed') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_midas(self): + processor = Processor('midas') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_mlsd(self): + processor = Processor('mlsd') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_openpose(self): + processor = Processor('openpose') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_pidinet(self): + processor = Processor('pidinet') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_normalbae(self): + processor = Processor('normalbae') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_lineart(self): + processor = Processor('lineart') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_lineart_coarse(self): + processor = Processor('lineart_coarse') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_lineart_anime(self): + processor = Processor('lineart_anime') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_canny(self): + processor = Processor('canny') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_content_shuffle(self): + processor = Processor('content_shuffle') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_zoe(self): + processor = Processor('zoe') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_mediapipe_face(self): + processor = Processor('mediapipe_face') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/controlnet_aux/tests/test_processor_pytest.py b/controlnet_aux/tests/test_processor_pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..c70c13a1af969220b5d681288610921ca3880992 --- /dev/null +++ b/controlnet_aux/tests/test_processor_pytest.py @@ -0,0 +1,77 @@ +import io + +import numpy as np +import pytest +from PIL import Image + +from controlnet_aux.processor import MODELS, Processor + + +@pytest.fixture(params=[ + 'scribble_hed', + 'softedge_hed', + 'scribble_hedsafe', + 'softedge_hedsafe', + 'depth_midas', + 'mlsd', + 'openpose', + 'openpose_hand', + 'openpose_face', + 'openpose_faceonly', + 'openpose_full', + 'scribble_pidinet', + 'softedge_pidinet', + 'scribble_pidsafe', + 'softedge_pidsafe', + 'normal_bae', + 'lineart_coarse', + 'lineart_realistic', + 'lineart_anime', + 'canny', + 'shuffle', + 'depth_zoe', + 'depth_leres', + 'depth_leres++', + 'mediapipe_face' +]) +def processor(request): + return Processor(request.param) + + +def test_processor_init(processor): + assert isinstance(processor.processor, MODELS[processor.processor_id]['class']) + assert isinstance(processor.params, dict) + + +def test_processor_call(processor): + # Load test image + with open('test_image.png', 'rb') as f: + image_bytes = f.read() + image = Image.open(io.BytesIO(image_bytes)) + + # Output size + resolution = 512 + W, H = image.size + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + + # Test processing + processed_image = processor(image) + assert isinstance(processed_image, Image.Image) + assert processed_image.size == (W, H) + + +def test_processor_call_bytes(processor): + # Load test image + with open('test_image.png', 'rb') as f: + image_bytes = f.read() + + # Test processing + processed_image_bytes = processor(image_bytes, to_pil=False) + assert isinstance(processed_image_bytes, bytes) + assert len(processed_image_bytes) > 0 \ No newline at end of file diff --git a/controlnet_aux/util.py b/controlnet_aux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..79ba7f120cc60bf50c849e6f4abab684b06bf388 --- /dev/null +++ b/controlnet_aux/util.py @@ -0,0 +1,146 @@ +import os +import random + +import cv2 +import numpy as np +import torch + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def make_noise_disk(H, W, C, F): + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + + if random.uniform(0, 1) < 0.5: + y = 255 - y + + return y < np.percentile(y, random.randrange(low, high)) + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def ade_palette(): + """ADE20K palette that maps each class to RGB values.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + diff --git a/controlnet_aux/zoe/LICENSE b/controlnet_aux/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/controlnet_aux/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/controlnet_aux/zoe/__init__.py b/controlnet_aux/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18fb5f0433af55d85b7761aad4110ba13beef841 --- /dev/null +++ b/controlnet_aux/zoe/__init__.py @@ -0,0 +1,84 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from ..util import HWC3, resize_image +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.models.zoedepth_nk.zoedepth_nk_v1 import ZoeDepthNK +from .zoedepth.utils.config import get_config + + +class ZoeDetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, model_type="zoedepth", filename=None, cache_dir=None, local_files_only=False): + filename = filename or "ZoeD_M12_N.pt" + + if os.path.isdir(pretrained_model_or_path): + model_path = os.path.join(pretrained_model_or_path, filename) + else: + model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) + + conf = get_config(model_type, "infer") + model_cls = ZoeDepth if model_type == "zoedepth" else ZoeDepthNK + model = model_cls.build_from_config(conf) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model']) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type=None, gamma_corrected=False): + device = next(iter(self.model.parameters())).device + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + input_image = HWC3(input_image) + input_image = resize_image(input_image, detect_resolution) + + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().to(device) + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + + if gamma_corrected: + depth = np.power(depth, 2.2) + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = depth_image + detected_map = HWC3(detected_map) + + img = resize_image(input_image, image_resolution) + H, W, C = img.shape + + detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/controlnet_aux/zoe/zoedepth/__init__.py b/controlnet_aux/zoe/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/zoe/zoedepth/models/__init__.py b/controlnet_aux/zoe/zoedepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/__init__.py b/controlnet_aux/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 0000000000000000000000000000000000000000..1af551be93b94ff4bd64c909ffdec7eeb17665ef --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,379 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + # print("Params passed to Resize transform:") + # print("\twidth: ", width) + # print("\theight: ", height) + # print("\tresize_target: ", resize_target) + # print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + # print("\tensure_multiple_of: ", ensure_multiple_of) + # print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, **kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + # print("img_size", img_size) + midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo') + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/LICENSE b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/README.md b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9568ea71c755b6938ee5482ba9f09be722e75943 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/README.md @@ -0,0 +1,259 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + + +and our [preprint](https://arxiv.org/abs/2103.13413): + +> Vision Transformers for Dense Prediction +> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun + + +MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with +multi-objective optimization. +The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). +The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. + +![](figures/Improvement_vs_FPS.png) + +### Setup + +1) Pick one or more models and download the corresponding weights to the `weights` folder: + +MiDaS 3.1 +- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) +- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt) +- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt) +- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin) + +MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) + +MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) + +1) Set up dependencies: + + ```shell + conda env create -f environment.yaml + conda activate midas-py310 + ``` + +#### optional + +For the Next-ViT model, execute + +```shell +git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit +``` + +For the OpenVINO model, install + +```shell +pip install openvino +``` + +### Usage + +1) Place one or more input images in the folder `input`. + +2) Run the model with + + ```shell + python run.py --model_type --input_path input --output_path output + ``` + where `````` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type), + [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type), + [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type), + [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type), + [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type). + +3) The resulting depth maps are written to the `output` folder. + +#### optional + +1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This + size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single + inference height but a range of different heights. Feel free to explore different heights by appending the extra + command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may + decrease the model accuracy. +2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is + supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution, + disregarding the aspect ratio while preserving the height, use the command line argument `--square`. + +#### via Camera + + If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths + away and choose a model type as shown above: + + ```shell + python run.py --model_type --side + ``` + + The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown + side-by-side for comparison. + +#### via Docker + +1) Make sure you have installed Docker and the + [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)). + +2) Build the Docker image: + + ```shell + docker build -t midas . + ``` + +3) Run inference: + + ```shell + docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas + ``` + + This command passes through all of your NVIDIA GPUs to the container, mounts the + `input` and `output` directories and then runs the inference. + +#### via PyTorch Hub + +The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/) + +#### via TensorFlow or ONNX + +See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory. + +Currently only supports MiDaS v2.1. + + +#### via Mobile (iOS / Android) + +See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory. + +#### via ROS1 (Robot Operating System) + +See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory. + +Currently only supports MiDaS v2.1. DPT-based models to be added. + + +### Accuracy + +We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets +(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**. +$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to +MiDaS 3.0 DPTL-384. The models are grouped by the height used for inference, whereas the square training resolution is given by +the numbers in the model names. The table also shows the **number of parameters** (in millions) and the +**frames per second** for inference at the training resolution (for GPU RTX 3090): + +| MiDaS Model | DIW
WHDR | Eth3d
AbsRel | Sintel
AbsRel | TUM
δ1 | KITTI
δ1 | NYUv2
δ1 | $\color{green}{\textsf{Imp.}}$
% | Par.
M | FPS
  | +|-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| +| **Inference height 512** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | +| | | | | | | | | | | +| **Inference height 384** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | +| [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | +| [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | +| [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | +| [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | +| [v3.1 Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | +| [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | +| [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | +| [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | +| [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | +| | | | | | | | | | | +| **Inference height 256** | | | | | | | | | | +| [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | +| [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | +| | | | | | | | | | | +| **Inference height 224** | | | | | | | | | | +| [v3.1 LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | + +* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ +$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model +does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other +validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the +improvement, because these quantities are averages over the pixels of an image and do not take into account the +advantage of more details due to a higher resolution.\ +Best values per column and same validation height in bold + +#### Improvement + +The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 +DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then +the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. + +Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the +improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 +and v2.0 Large384 respectively instead of v3.0 DPTL-384. + +### Depth map comparison + +Zoom in for better visibility +![](figures/Comparison.png) + +### Speed on Camera Feed + +Test configuration +- Windows 10 +- 11th Gen Intel Core i7-1185G7 3.00GHz +- 16GB RAM +- Camera resolution 640x480 +- openvino_midas_v21_small_256 + +Speed: 22 FPS + +### Changelog + +* [Dec 2022] Released MiDaS v3.1: + - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) + - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split + - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 + - Integrated live depth estimation from camera feed +* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). +* [Apr 2021] Released MiDaS v3.0: + - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 + - Additional models can be found [here](https://github.com/isl-org/DPT) +* [Nov 2020] Released MiDaS v2.1: + - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) + - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. + - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) + - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots +* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). +* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust +* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@ARTICLE {Ranftl2022, + author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", + title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", + journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", + year = "2022", + volume = "44", + number = "3" +} +``` + +If you use a DPT-based model, please also cite: + +``` +@article{Ranftl2021, + author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, + title = {Vision Transformers for Dense Prediction}, + journal = {ICCV}, + year = {2021}, +} +``` + +### Acknowledgements + +Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). +We'd like to thank the authors for making these libraries available. + +### License + +MIT License diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/__init__.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/hubconf.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d638be5151c4e305daff0c47d1ea3fc8066377d --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/hubconf.py @@ -0,0 +1,435 @@ +dependencies = ["torch"] + +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small + +def DPT_BEiT_L_512(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_512 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_512", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitb16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2l24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2b24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_T_256(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_T_256 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2t16_256", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Swin_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Swin_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swinl12_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Next_ViT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="next_vit_large_6m", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_LeViT_224(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_LeViT_224 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Large(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Large model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Hybrid(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Hybrid model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitb_rn50_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet() + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS_small(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + + +def transforms(): + import cv2 + from torchvision.transforms import Compose + from midas.transforms import Resize, NormalizeImage, PrepareForNet + from midas import transforms + + transforms.default_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.small_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.dpt_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.beit512_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 512, + 512, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin384_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin256_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.levit_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 224, + 224, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + return transforms diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..7a24e02cd2b979844bf638b46ac60949ee9ce691 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py @@ -0,0 +1,196 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. + """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. + """ + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[0, 4, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py @@ -0,0 +1,106 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + +from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. + """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. + """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py @@ -0,0 +1,39 @@ +import timm + +import torch.nn as nn + +from pathlib import Path +from .utils import activations, forward_default, get_activation + +from ..external.next_vit.classification.nextvit import * + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=[2, 6, 36, 39], +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks == None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py new file mode 100644 index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py @@ -0,0 +1,52 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=[1, 1, 17, 1], + patch_grid=[96, 96] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py @@ -0,0 +1,249 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def make_backbone_default( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + patch_size=[16, 16], + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + value = nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py @@ -0,0 +1,439 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": + pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3129d09cb43a7c79b23916236991fabbedb78f55 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. + hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cd1f2d43054bfd3d650587c7b2ed35f1347c9e --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. + + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + if not "openvino" in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if not "openvino" in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if not "openvino" in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if not "openvino" in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/controlnet_aux/zoe/zoedepth/models/builder.py b/controlnet_aux/zoe/zoedepth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. Refer above error for details.") from e + try: + get_version = getattr(module, "get_version") + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/controlnet_aux/zoe/zoedepth/models/depth_model.py b/controlnet_aux/zoe/zoedepth/models/depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self): + super().__init__() + self.device = 'cpu' + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. + upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. + padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + @torch.no_grad() + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". + """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") + \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/layers/__init__.py b/controlnet_aux/zoe/zoedepth/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c344f725c8a10dcaf29d4c308eb49d86ac51ff88 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/layers/__init__.py @@ -0,0 +1,23 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat diff --git a/controlnet_aux/zoe/zoedepth/models/layers/attractor.py b/controlnet_aux/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] # n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py b/controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. + + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/controlnet_aux/zoe/zoedepth/models/layers/localbins_layers.py b/controlnet_aux/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. + """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...]) + return b, B_centers \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py b/controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. + """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? + embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/controlnet_aux/zoe/zoedepth/models/model_io.py b/controlnet_aux/zoe/zoedepth/models/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/model_io.py @@ -0,0 +1,92 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") + \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py b/controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/controlnet_aux/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,250 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py b/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..568ac512ae0462c499cbf424eca41bfc2328bc16 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,332 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..layers.patch_transformer import PatchTransformerEncoder +from ..model_io import load_state_from_resource + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + + """ + + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. + return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. + + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError: + raise ValueError( + f"bin_conf_name {bin_conf_name} not found in bin_confs") + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/controlnet_aux/zoe/zoedepth/utils/__init__.py b/controlnet_aux/zoe/zoedepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/controlnet_aux/zoe/zoedepth/utils/arg_utils.py b/controlnet_aux/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/controlnet_aux/zoe/zoedepth/utils/config.py b/controlnet_aux/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. + Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. + + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py b/controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file