""" helper function | |
author junde | |
""" | |
import collections
import logging
import math
import os
import pathlib
import random
import shutil
import sys
import tempfile
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from typing import BinaryIO, List, Optional, Text, Tuple, Union

import dateutil.tz
import matplotlib.pyplot as plt
import numpy
import numpy as np
import PIL
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from monai.config import print_config
from monai.data import (CacheDataset, ThreadDataLoader, decollate_batch,
                        load_decathlon_datalist, set_track_meta)
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.transforms import (AsDiscrete, Compose, CropForegroundd,
                              EnsureTyped, LoadImaged, Orientationd,
                              RandCropByPosNegLabeld, RandFlipd, RandRotate90d,
                              RandShiftIntensityd, ScaleIntensityRanged,
                              Spacingd)
from PIL import Image, ImageColor, ImageDraw, ImageFont
from torch import autograd
from torch.autograd import Function, Variable
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
# from lucent.optvis.param.spatial import pixel_image, fft_image, init_image
# from lucent.optvis.param.color import to_valid_rgb
# from lucent.optvis import objectives, transform, param
# from lucent.misc.io import show
from torchvision.models import vgg19
from tqdm import tqdm

import cfg
# from precpt import run_precpt
from models.discriminator import Discriminator
# from siren_pytorch import SirenNet, SirenWrapper

args = cfg.parse_args()
device = torch.device('cuda', args.gpu_device)

'''preparation of domain loss'''
# cnn = vgg19(pretrained=True).features.to(device).eval()
# cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
# cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
# netD = Discriminator(1).to(device)
# netD.apply(init_D)
# beta1 = 0.5
# dis_lr = 0.0002
# optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
'''end'''

def get_network(args, net, use_gpu=True, gpu_device=0, distribution=True):
    """Return the requested network, optionally moved to GPU and wrapped
    in DataParallel when a distributed device list is given.
    """
    if net == 'sam':
        from models.sam import SamPredictor, sam_model_registry
        from models.sam.utils.transforms import ResizeLongestSide
        options = ['default', 'vit_b', 'vit_l', 'vit_h']
        if args.encoder not in options:
            raise ValueError("Invalid encoder option. Please choose from: {}".format(options))
        else:
            net = sam_model_registry[args.encoder](args, checkpoint=args.sam_ckpt).to(device)
    elif net == 'efficient_sam':
        from models.efficient_sam import sam_model_registry
        options = ['default', 'vit_s', 'vit_t']
        if args.encoder not in options:
            raise ValueError("Invalid encoder option. Please choose from: {}".format(options))
        else:
            net = sam_model_registry[args.encoder](args)
    elif net == 'mobile_sam':
        from models.MobileSAMv2.mobilesamv2 import sam_model_registry
        options = ['default', 'vit_h', 'vit_l', 'vit_b', 'tiny_vit', 'efficientvit_l2', 'PromptGuidedDecoder', 'sam_vit_h']
        if args.encoder not in options:
            raise ValueError("Invalid encoder option. Please choose from: {}".format(options))
        else:
            net = sam_model_registry[args.encoder](args, checkpoint=args.sam_ckpt)
    else:
        print('the network name you entered is not supported yet')
        sys.exit()

    if use_gpu:
        # net = net.cuda(device=gpu_device)
        if distribution != 'none':
            net = torch.nn.DataParallel(net, device_ids=[int(id) for id in args.distributed.split(',')])
            net = net.to(device=gpu_device)
        else:
            net = net.to(device=gpu_device)
    return net

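# Illustrative usage (hedged; assumes cfg provides the `gpu`, `encoder`,
# `sam_ckpt` and `distributed` fields used above):
# net = get_network(args, 'sam', use_gpu=args.gpu, gpu_device=device,
#                   distribution=args.distributed)
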
def get_decath_loader(args):
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(args.roi_size, args.roi_size, args.chunk),
                pos=1,
                neg=1,
                num_samples=args.num_sample,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[0],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[1],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[2],
                prob=0.10,
            ),
            RandRotate90d(
                keys=["image", "label"],
                prob=0.10,
                max_k=3,
            ),
            RandShiftIntensityd(
                keys=["image"],
                offsets=0.10,
                prob=0.50,
            ),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
        ]
    )

    data_dir = args.data_path
    split_JSON = "dataset_0.json"
    datasets = os.path.join(data_dir, split_JSON)
    datalist = load_decathlon_datalist(datasets, True, "training")
    val_files = load_decathlon_datalist(datasets, True, "validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=8,
    )
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.b, shuffle=True)
    val_ds = CacheDataset(
        data=val_files, transform=val_transforms, cache_num=2, cache_rate=1.0, num_workers=0
    )
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
    set_track_meta(False)

    return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files

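# Example (illustrative): assumes args provides data_path, roi_size, chunk,
# num_sample and b (batch size), and that dataset_0.json follows the
# Decathlon datalist format with "training" and "validation" keys.
# train_loader, val_loader, *_ = get_decath_loader(args)
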
def cka_loss(gram_featureA, gram_featureB):
    scaled_hsic = torch.dot(torch.flatten(gram_featureA), torch.flatten(gram_featureB))
    normalization_x = gram_featureA.norm()
    normalization_y = gram_featureB.norm()
    return scaled_hsic / (normalization_x * normalization_y)

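# Example (illustrative): CKA similarity between two Gram matrices built with
# gram_matrix() below; values close to 1 indicate similar representations.
# `featA` and `featB` are hypothetical [1, C, H, W] activation tensors.
# gA = gram_matrix(featA)
# gB = gram_matrix(featB)
# sim = cka_loss(gA, gB)
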
class WarmUpLR(_LRScheduler):
    """Warmup learning rate scheduler.

    Args:
        optimizer: optimizer (e.g. SGD)
        total_iters: total iterations of the warmup phase
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Use the first m batches and set the learning
        rate to base_lr * m / total_iters.
        """
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]

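# Example (illustrative): linear warmup over the first epoch, stepped once per
# batch. `net` and `train_loader` are assumed to exist in the caller's scope.
# optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
# warmup_iters = len(train_loader)  # one warmup epoch
# scheduler = WarmUpLR(optimizer, total_iters=warmup_iters)
# for images, labels in train_loader:
#     ...  # forward / backward
#     optimizer.step()
#     scheduler.step()
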
def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)
    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL
    G = torch.mm(features, features.t())  # compute the gram product
    # 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map
    return G.div(a * b * c * d)

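# Example (illustrative): Gram matrix of a conv feature map, as used in style
# and CKA losses; the input must be a 4D activation tensor.
# feat = torch.randn(1, 64, 32, 32)
# G = gram_matrix(feat)  # -> [64, 64], normalized by a * b * c * d
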
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs.keys():
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)  # in-place; a plain clamp() here would be a no-op
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid

def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object.
        format (Optional): If omitted, the format to use is determined from the
            filename extension. If a file object was used instead of a filename,
            this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)

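# Example (illustrative): save a mini-batch as an 8-column grid, rescaling each
# image to [0, 1] first.
# batch = torch.rand(16, 3, 64, 64)
# save_image(batch, 'sample.png', nrow=8, normalize=True, scale_each=True)
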
def create_logger(log_dir, phase='train'):
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)
    return logger

def set_log_dir(root_dir, exp_name):
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict

def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

class RunningStats:
    def __init__(self, WIN_SIZE):
        self.mean = 0
        self.run_var = 0
        self.WIN_SIZE = WIN_SIZE
        self.window = collections.deque(maxlen=WIN_SIZE)

    def clear(self):
        self.window.clear()
        self.mean = 0
        self.run_var = 0

    def is_full(self):
        return len(self.window) == self.WIN_SIZE

    def push(self, x):
        if len(self.window) == self.WIN_SIZE:
            # Adjusting variance
            x_removed = self.window.popleft()
            self.window.append(x)
            old_m = self.mean
            self.mean += (x - x_removed) / self.WIN_SIZE
            self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
        else:
            # Calculating first variance
            self.window.append(x)
            delta = x - self.mean
            self.mean += delta / len(self.window)
            self.run_var += delta * (x - self.mean)

    def get_mean(self):
        return self.mean if len(self.window) else 0.0

    def get_var(self):
        return self.run_var / len(self.window) if len(self.window) > 1 else 0.0

    def get_std(self):
        return math.sqrt(self.get_var())

    def get_all(self):
        return list(self.window)

    def __str__(self):
        return "Current window values: {}".format(list(self.window))

def iou(outputs: np.ndarray, labels: np.ndarray):
    """Mean IoU over the batch; expects integer/boolean arrays of shape [b, h, w]."""
    SMOOTH = 1e-6
    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))
    iou = (intersection + SMOOTH) / (union + SMOOTH)
    return iou.mean()

class DiceCoeff(Function):
    """Dice coeff for individual examples"""

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps
        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        input, target = self.saved_variables
        grad_input = grad_target = None
        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None
        return grad_input, grad_target

def dice_coeff(input, target):
    """Dice coeff for batches"""
    if input.is_cuda:
        s = torch.FloatTensor(1).to(device=input.device).zero_()
    else:
        s = torch.FloatTensor(1).zero_()
    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])
    return s / (i + 1)

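# Example (illustrative): Dice score for a batch of binarized predictions;
# both tensors are expected in [0, 1] with shape [b, h, w].
# preds = (torch.rand(4, 256, 256) > 0.5).float()
# gts = (torch.rand(4, 256, 256) > 0.5).float()
# score = dice_coeff(preds, gts)  # 1-element tensor
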
'''parameter'''
def para_image(w, h=None, img=None, mode='multi', seg=None, sd=None, batch=None,
               fft=False, channels=None, init=None):
    # NOTE: pixel_image, fft_image and init_image come from the lucent imports
    # at the top of this file, which are currently commented out.
    h = h or w
    batch = batch or 1
    ch = channels or 3
    shape = [batch, ch, h, w]
    param_f = fft_image if fft else pixel_image
    if init is not None:
        param_f = init_image
        params, maps_f = param_f(init)
    else:
        params, maps_f = param_f(shape, sd=sd)
    if mode == 'multi':
        output = to_valid_out(maps_f, img, seg)
    elif mode == 'seg':
        output = gene_out(maps_f, img)
    elif mode == 'raw':
        output = raw_out(maps_f, img)
    return params, output

def to_valid_out(maps_f, img, seg):  # multi-rater
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        maps = torch.nn.Softmax(dim=1)(maps)
        final_seg = torch.multiply(seg, maps).sum(dim=1, keepdim=True)
        return torch.cat((img, final_seg), 1)
        # return torch.cat((img, maps), 1)
    return inner

def gene_out(maps_f, img):  # pure seg
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return torch.cat((img, maps), 1)
    return inner

def raw_out(maps_f, img):  # raw
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return maps
    return inner

class CompositeActivation(torch.nn.Module):
    def forward(self, x):
        x = torch.atan(x)
        return torch.cat([x / 0.67, (x * x) / 0.6], 1)
        # return x

def cppn(args, size, img=None, seg=None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8,
         activation_fn=CompositeActivation, normalize=False, device="cuda:0"):
    r = 3 ** 0.5
    coord_range = torch.linspace(-r, r, size)
    x = coord_range.view(-1, 1).repeat(1, coord_range.size(0))
    y = coord_range.view(1, -1).repeat(coord_range.size(0), 1)
    input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).repeat(batch, 1, 1, 1).to(device)

    layers = []
    kernel_size = 1
    for i in range(num_layers):
        out_c = num_hidden_channels
        in_c = out_c * 2  # * 2 for composite activation
        if i == 0:
            in_c = 2
        if i == num_layers - 1:
            out_c = num_output_channels
        layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
        if normalize:
            layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
        if i < num_layers - 1:
            layers.append(('actv{}'.format(i), activation_fn()))
        else:
            layers.append(('output', torch.nn.Sigmoid()))

    # Initialize model
    net = torch.nn.Sequential(OrderedDict(layers)).to(device)

    # Initialize weights
    def weights_init(module):
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.normal_(module.weight, 0, np.sqrt(1 / module.in_channels))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
    net.apply(weights_init)

    # Set the last conv2d layer's weights to 0
    torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)

    outimg = raw_out(lambda: net(input_tensor), img) if args.netype == 'raw' else to_valid_out(lambda: net(input_tensor), img, seg)
    return net.parameters(), outimg

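# Example (illustrative): a CPPN image parameterization; assumes args.netype is
# set and that `img` (and `seg` for the non-raw mode) are tensors on `device`.
# params, image_f = cppn(args, size=256, img=img, seg=seg, batch=1)
# optimizer = torch.optim.Adam(params, lr=5e-3)
# out = image_f()  # regenerate the image from the current CPPN weights
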
def get_siren(args):
    # NOTE: 'siren' and 'vae' are not registered in get_network above; this
    # helper is kept from an earlier version of the codebase.
    wrapper = get_network(args, 'siren', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution=args.distributed)
    '''load init weights'''
    checkpoint = torch.load('./logs/siren_train_init_2022_08_19_21_00_16/Model/checkpoint_best.pth')
    wrapper.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    '''load prompt'''
    checkpoint = torch.load('./logs/vae_standard_refuge1_2022_08_21_17_56_49/Model/checkpoint500')
    vae = get_network(args, 'vae', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution=args.distributed)
    vae.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''
    return wrapper, vae

def siren(args, wrapper, vae, img=None, seg=None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8,
          activation_fn=CompositeActivation, normalize=False, device="cuda:0"):
    vae_img = torchvision.transforms.Resize(64)(img)
    latent = vae.encoder(vae_img).view(-1).detach()
    outimg = raw_out(lambda: wrapper(latent=latent), img) if args.netype == 'raw' else to_valid_out(lambda: wrapper(latent=latent), img, seg)
    # img = torch.randn(1, 3, 256, 256)
    # loss = wrapper(img)
    # loss.backward()
    # # after much training ...
    # # simply invoke the wrapper without passing in anything
    # pred_img = wrapper()  # (1, 3, 256, 256)
    return wrapper.parameters(), outimg

'''adversary'''
def render_vis(
    args,
    model,
    objective_f,
    real_img,
    param_f=None,
    optimizer=None,
    transforms=None,
    thresholds=(256,),
    verbose=True,
    preprocess=True,
    progress=True,
    show_image=True,
    save_image=False,
    image_name=None,
    show_inline=False,
    fixed_image_size=None,
    label=1,
    raw_img=None,
    prompt=None
):
    if label == 1:
        sign = 1
    elif label == 0:
        sign = -1
    else:
        raise ValueError('label should be 0 or 1, got {}'.format(label))
    if args.reverse:
        sign = -sign
    if args.multilayer:
        sign = 1

    '''prepare'''
    now = datetime.now()
    date_time = now.strftime("%m-%d-%Y, %H:%M:%S")
    netD, optD = pre_d()
    '''end'''

    # NOTE: param, transform, objectives and show used below come from the
    # lucent imports at the top of this file, which are currently commented out.
    if param_f is None:
        param_f = lambda: param.image(128)
    # param_f is a function that should return two things:
    # params - parameters to update, which we pass to the optimizer
    # image_f - a function that returns an image as a tensor
    params, image_f = param_f()

    if optimizer is None:
        optimizer = lambda params: torch.optim.Adam(params, lr=5e-1)
    optimizer = optimizer(params)

    if transforms is None:
        transforms = []
    transforms = transforms.copy()

    # Upsample images smaller than 224
    image_shape = image_f().shape
    if fixed_image_size is not None:
        new_size = fixed_image_size
    elif image_shape[2] < 224 or image_shape[3] < 224:
        new_size = 224
    else:
        new_size = None
    if new_size:
        transforms.append(
            torch.nn.Upsample(size=new_size, mode="bilinear", align_corners=True)
        )

    transform_f = transform.compose(transforms)
    hook = hook_model(model, image_f)
    objective_f = objectives.as_objective(objective_f)

    if verbose:
        model(transform_f(image_f()))
        print("Initial loss of ad: {:.3f}".format(objective_f(hook)))

    images = []
    try:
        for i in tqdm(range(1, max(thresholds) + 1), disable=(not progress)):
            optimizer.zero_grad()
            try:
                model(transform_f(image_f()))
            except RuntimeError as ex:
                if i == 1:
                    # Only display the warning message on the first
                    # iteration, no need to do that every iteration
                    warnings.warn(
                        "Some layers could not be computed because the size of the "
                        "image is not big enough. It is fine, as long as the non-"
                        "computed layers are not used in the objective function "
                        f"(exception details: '{ex}')"
                    )
            if args.disc:
                '''dom loss part'''
                # content_img = raw_img
                # style_img = raw_img
                # precpt_loss = run_precpt(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, transform_f(image_f()))
                for p in netD.parameters():
                    p.requires_grad = True
                for _ in range(args.drec):
                    netD.zero_grad()
                    real = real_img
                    fake = image_f()
                    # for _ in range(6):
                    #     errD, D_x, D_G_z1 = update_d(args, netD, optD, real, fake)
                    # label = torch.full((args.b,), 1., dtype=torch.float, device=device)
                    # label.fill_(1.)
                    # output = netD(fake).view(-1)
                    # errG = nn.BCELoss()(output, label)
                    # D_G_z2 = output.mean().item()
                    # dom_loss = err
                    one = torch.tensor(1, dtype=torch.float)
                    mone = one * -1
                    one = one.cuda(args.gpu_device)
                    mone = mone.cuda(args.gpu_device)

                    d_loss_real = netD(real)
                    d_loss_real = d_loss_real.mean()
                    d_loss_real.backward(mone)

                    d_loss_fake = netD(fake)
                    d_loss_fake = d_loss_fake.mean()
                    d_loss_fake.backward(one)

                    # Train with gradient penalty
                    gradient_penalty = calculate_gradient_penalty(netD, real.data, fake.data)
                    gradient_penalty.backward()

                    d_loss = d_loss_fake - d_loss_real + gradient_penalty
                    Wasserstein_D = d_loss_real - d_loss_fake
                    optD.step()

                # Generator update
                for p in netD.parameters():
                    p.requires_grad = False  # to avoid computation

                fake_images = image_f()
                g_loss = netD(fake_images)
                g_loss = -g_loss.mean()
                dom_loss = g_loss
                g_cost = -g_loss

                if i % 5 == 0:
                    print(f'loss_fake: {d_loss_fake}, loss_real: {d_loss_real}')
                    print(f'Generator g_loss: {g_loss}')
                '''end'''

            '''ssim loss'''
            '''end'''

            if args.disc:
                loss = sign * objective_f(hook) + args.pw * dom_loss
                # loss = args.pw * dom_loss
            else:
                loss = sign * objective_f(hook)
                # loss = args.pw * dom_loss
            loss.backward()

            # # video the images
            # if i % 5 == 0:
            #     image_name = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
            #     img_path = os.path.join(args.path_helper['sample_path'], str(image_name))
            #     export(image_f(), img_path)
            # # end
            # if i % 50 == 0:
            #     print('Loss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            #           % (errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            optimizer.step()

            if i in thresholds:
                image = tensor_to_img_array(image_f())
                # if verbose:
                #     print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
                if save_image:
                    na = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
                    na = date_time + na
                    outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
                    img_path = os.path.join(outpath, str(na))
                    export(image_f(), img_path)
                images.append(image)
    except KeyboardInterrupt:
        print("Interrupted optimization at step {:d}.".format(i))
        if verbose:
            print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
        images.append(tensor_to_img_array(image_f()))

    if save_image:
        na = image_name[0].split('\\')[-1].split('.')[0] + '.png'
        na = date_time + na
        outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
        img_path = os.path.join(outpath, str(na))
        export(image_f(), img_path)
    if show_inline:
        show(tensor_to_img_array(image_f()))
    elif show_image:
        view(image_f())
    return image_f()

def tensor_to_img_array(tensor):
    image = tensor.cpu().detach().numpy()
    image = np.transpose(image, [0, 2, 3, 1])
    return image

def view(tensor):
    image = tensor_to_img_array(tensor)
    assert len(image.shape) in [
        3,
        4,
    ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
    # Change dtype for PIL.Image
    image = (image * 255).astype(np.uint8)
    if len(image.shape) == 4:
        image = np.concatenate(image, axis=1)
    Image.fromarray(image).show()

def export(tensor, img_path=None):
    # image_name = image_name or "image.jpg"
    c = tensor.size(1)
    # if c == 7:
    #     for i in range(c):
    #         w_map = tensor[:, i, :, :].unsqueeze(1)
    #         w_map = tensor_to_img_array(w_map).squeeze()
    #         w_map = (w_map * 255).astype(np.uint8)
    #         image_name = image_name[0].split('/')[-1].split('.')[0] + str(i) + '.png'
    #         wheat = sns.heatmap(w_map, cmap='coolwarm')
    #         figure = wheat.get_figure()
    #         figure.savefig('./fft_maps/weightheatmap/' + str(image_name), dpi=400)
    #         figure = 0
    # else:
    if c == 3:
        vutils.save_image(tensor, fp=img_path)
    else:
        image = tensor[:, 0:3, :, :]
        w_map = tensor[:, -1, :, :].unsqueeze(1)
        image = tensor_to_img_array(image)
        w_map = 1 - tensor_to_img_array(w_map).squeeze()
        # w_map[w_map == 1] = 0
        assert len(image.shape) in [
            3,
            4,
        ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
        # Change dtype for PIL.Image
        image = (image * 255).astype(np.uint8)
        w_map = (w_map * 255).astype(np.uint8)
        Image.fromarray(w_map, 'L').save(img_path)

class ModuleHook:
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.module = None
        self.features = None

    def hook_fn(self, module, input, output):
        self.module = module
        self.features = output

    def close(self):
        self.hook.remove()

def hook_model(model, image_f):
    features = OrderedDict()

    # recursive hooking function
    def hook_layers(net, prefix=[]):
        if hasattr(net, "_modules"):
            for name, layer in net._modules.items():
                if layer is None:
                    # e.g. GoogLeNet's aux1 and aux2 layers
                    continue
                features["_".join(prefix + [name])] = ModuleHook(layer)
                hook_layers(layer, prefix=prefix + [name])

    hook_layers(model)

    def hook(layer):
        if layer == "input":
            out = image_f()
        elif layer == "labels":
            out = list(features.values())[-1].features
        else:
            assert layer in features, f"Invalid layer {layer}. Retrieve the list of layers with `lucent.modelzoo.util.get_model_layers(model)`."
            out = features[layer].features
        assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example."
        return out

    return hook

def vis_image(imgs, pred_masks, gt_masks, save_path, reverse=False, points=None, boxes=None):
    b, c, h, w = pred_masks.size()
    dev = pred_masks.get_device()
    row_num = min(b, 4)

    if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0:
        pred_masks = torch.sigmoid(pred_masks)

    if reverse:
        pred_masks = 1 - pred_masks
        gt_masks = 1 - gt_masks
    if c == 2:  # for REFUGE multi mask output
        pred_disc, pred_cup = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), pred_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_disc, gt_cup = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), gt_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        tup = (imgs[:row_num, :, :, :], pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :])
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    elif c > 2:  # for multi-class segmentation > 2 classes
        preds = []
        gts = []
        for i in range(0, c):
            pred = pred_masks[:, i, :, :].unsqueeze(1).expand(b, 3, h, w)
            preds.append(pred)
            gt = gt_masks[:, i, :, :].unsqueeze(1).expand(b, 3, h, w)
            gts.append(gt)
        tup = [imgs[:row_num, :, :, :]] + preds + gts
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    else:
        imgs = torchvision.transforms.Resize((h, w))(imgs)
        if imgs.size(1) == 1:
            imgs = imgs[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        pred_masks = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_masks = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        if points is not None:
            for i in range(b):
                if args.thd:
                    ps = torch.round(points[i].cpu() / args.roi_size * args.out_size).to(dtype=torch.int)
                else:
                    ps = torch.round(points[i].cpu() / args.image_size * args.out_size).to(dtype=torch.int)
                # gt_masks[i, :, points[i, 0]-5:points[i, 0]+5, points[i, 1]-5:points[i, 1]+5] = torch.Tensor([255, 0, 0]).to(dtype=torch.float32, device=torch.device('cuda:' + str(dev)))
                # mark each click of sample i with a colored square
                # (assumes points has shape [b, n_clicks, 2] or [b, 2])
                for p in ps.reshape(-1, 2):
                    gt_masks[i, 0, p[0]-5:p[0]+5, p[1]-5:p[1]+5] = 0.5
                    gt_masks[i, 1, p[0]-5:p[0]+5, p[1]-5:p[1]+5] = 0.1
                    gt_masks[i, 2, p[0]-5:p[0]+5, p[1]-5:p[1]+5] = 0.4
        if boxes is not None:
            for i in range(b):
                # the next line causes: ValueError: Tensor uint8 expected, got torch.float32
                # imgs[i, :] = torchvision.utils.draw_bounding_boxes(imgs[i, :], boxes[i])
                # until TorchVision 0.19 is released (paired with PyTorch 2.4), apply this workaround:
                img255 = (imgs[i] * 255).byte()
                img255 = torchvision.utils.draw_bounding_boxes(img255, boxes[i].reshape(-1, 4), colors="red")
                img01 = img255 / 255
                # torchvision.utils.save_image(img01, save_path + "_boxes.png")
                imgs[i, :] = img01
        tup = (imgs[:row_num, :, :, :], pred_masks[:row_num, :, :, :], gt_masks[:row_num, :, :, :])
        # compose = torch.cat((imgs[:row_num], pred_disc[:row_num], pred_cup[:row_num], gt_disc[:row_num], gt_cup[:row_num]), 0)
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    return

def eval_seg(pred, true_mask_p, threshold):
    '''
    threshold: an int or a tuple of ints
    masks: [b, 2, h, w]
    pred: [b, 2, h, w]
    '''
    b, c, h, w = pred.size()
    if c == 2:
        iou_d, iou_c, disc_dice, cup_dice = 0, 0, 0, 0
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
            cup_pred = vpred_cpu[:, 1, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].cpu().numpy().astype('int32')
            cup_mask = gt_vmask_p[:, 1, :, :].cpu().numpy().astype('int32')

            '''iou for numpy'''
            iou_d += iou(disc_pred, disc_mask)
            iou_c += iou(cup_pred, cup_mask)

            '''dice for torch'''
            disc_dice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
            cup_dice += dice_coeff(vpred[:, 1, :, :], gt_vmask_p[:, 1, :, :]).item()

        return iou_d / len(threshold), iou_c / len(threshold), disc_dice / len(threshold), cup_dice / len(threshold)
    elif c > 2:  # for multi-class segmentation > 2 classes
        ious = [0] * c
        dices = [0] * c
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            for i in range(0, c):
                # use a fresh name here; reusing `pred` would clobber the
                # function argument and break later threshold iterations
                pred_i = vpred_cpu[:, i, :, :].numpy().astype('int32')
                mask_i = gt_vmask_p[:, i, :, :].cpu().numpy().astype('int32')

                '''iou for numpy'''
                ious[i] += iou(pred_i, mask_i)

                '''dice for torch'''
                dices[i] += dice_coeff(vpred[:, i, :, :], gt_vmask_p[:, i, :, :]).item()

        return tuple(np.array(ious + dices) / len(threshold))  # tuple of c * 2 values
    else:
        eiou, edice = 0, 0
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].cpu().numpy().astype('int32')

            '''iou for numpy'''
            eiou += iou(disc_pred, disc_mask)

            '''dice for torch'''
            edice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()

        return eiou / len(threshold), edice / len(threshold)

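# Example (illustrative): average IoU/Dice over several binarization thresholds
# for a two-channel (disc/cup) REFUGE-style prediction. `pred` and `true_mask_p`
# are hypothetical [b, 2, h, w] float tensors.
# thresholds = (0.3, 0.5, 0.7)
# iou_d, iou_c, dice_d, dice_c = eval_seg(pred, true_mask_p, thresholds)
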
# @objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    def inner(T):
        dot = (T(layer)[batch] * T(layer)[0]).sum()
        mag = torch.sqrt(torch.sum(T(layer)[0] ** 2))
        cossim = dot / (1e-6 + mag)
        return -dot * cossim ** cossim_pow
    return inner

def init_D(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def pre_d():
    netD = Discriminator(3).to(device)
    # netD.apply(init_D)
    beta1 = 0.5
    dis_lr = 0.00002
    optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
    return netD, optimizerD

def update_d(args, netD, optimizerD, real, fake):
    criterion = nn.BCELoss()

    label = torch.full((args.b,), 1., dtype=torch.float, device=device)
    output = netD(real).view(-1)
    # Calculate loss on all-real batch
    errD_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    label.fill_(0.)
    # Classify all fake batch with D
    output = netD(fake.detach()).view(-1)
    # Calculate D's loss on the all-fake batch
    errD_fake = criterion(output, label)
    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Compute error of D as sum over the fake and the real batches
    errD = errD_real + errD_fake
    # Update D
    optimizerD.step()
    return errD, D_x, D_G_z1

def calculate_gradient_penalty(netD, real_images, fake_images):
    eta = torch.FloatTensor(args.b, 1, 1, 1).uniform_(0, 1)
    eta = eta.expand(args.b, real_images.size(1), real_images.size(2), real_images.size(3)).to(device=device)

    interpolated = (eta * real_images + ((1 - eta) * fake_images)).to(device=device)
    # define it to calculate gradient
    interpolated = Variable(interpolated, requires_grad=True)

    # calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # calculate gradients of probabilities with respect to examples
    gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                              grad_outputs=torch.ones(
                                  prob_interpolated.size()).to(device=device),
                              create_graph=True, retain_graph=True)[0]

    grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
    return grad_penalty

def random_click(mask, point_labels=1):
    # check if all masks are black
    max_label = max(set(mask.flatten()))
    if max_label == 0:
        point_labels = max_label
    # max agreement position
    indices = np.argwhere(mask == max_label)
    return point_labels, indices[np.random.randint(len(indices))]

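# Example (illustrative): sample a positive click from a binary numpy mask;
# returns the (possibly demoted) point label and a (row, col) index.
# mask = np.zeros((256, 256), dtype=np.uint8)
# mask[100:120, 100:120] = 1
# label, point = random_click(mask, point_labels=1)
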
def generate_click_prompt(img, msk, pt_label=1):
    # return: prompt, prompt mask
    pt_list = []
    msk_list = []
    b, c, h, w, d = msk.size()
    msk = msk[:, 0, :, :, :]
    for i in range(d):
        pt_list_s = []
        msk_list_s = []
        for j in range(b):
            msk_s = msk[j, :, :, i]
            indices = torch.nonzero(msk_s)
            if indices.size(0) == 0:
                # empty slice: fall back to a random (row, col) index in [0, h)
                random_index = torch.randint(0, h, (2,)).to(device=msk.device)
                new_s = msk_s
            else:
                random_index = random.choice(indices)
                label = msk_s[random_index[0], random_index[1]]
                # convert the bool comparison to a float mask
                new_s = (msk_s == label).to(dtype=torch.float)
                # new_s[msk_s == label] = 1
            pt_list_s.append(random_index)
            msk_list_s.append(new_s)
        pts = torch.stack(pt_list_s, dim=0)
        msks = torch.stack(msk_list_s, dim=0)
        pt_list.append(pts)
        msk_list.append(msks)
    pt = torch.stack(pt_list, dim=-1)
    msk = torch.stack(msk_list, dim=-1)

    msk = msk.unsqueeze(1)

    return img, pt, msk  # [b, 2, d], [b, c, h, w, d]

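# Example (illustrative): build per-slice click prompts from a 3D mask volume.
# img3d = torch.rand(2, 1, 96, 96, 32)
# msk3d = (torch.rand(2, 1, 96, 96, 32) > 0.5).float()
# img3d, pt, msk_prompt = generate_click_prompt(img3d, msk3d)
# # pt: [b, 2, d] click coordinates; msk_prompt: [b, 1, h, w, d]
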
def random_box(multi_rater):
    max_value = torch.max(multi_rater[:, 0, :, :], dim=0)[0]
    max_value_position = torch.nonzero(max_value)

    x_coords = max_value_position[:, 0]
    y_coords = max_value_position[:, 1]

    x_min = int(torch.min(x_coords))
    x_max = int(torch.max(x_coords))
    y_min = int(torch.min(y_coords))
    y_max = int(torch.max(y_coords))

    x_min = random.choice(np.arange(x_min - 10, x_min + 11))
    x_max = random.choice(np.arange(x_max - 10, x_max + 11))
    y_min = random.choice(np.arange(y_min - 10, y_min + 11))
    y_max = random.choice(np.arange(y_max - 10, y_max + 11))

    return x_min, x_max, y_min, y_max

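# Example (illustrative): jittered bounding box around the consensus foreground
# of a multi-rater mask stack of shape [n_raters, c, h, w].
# multi_rater = (torch.rand(7, 1, 256, 256) > 0.5).float()
# x_min, x_max, y_min, y_max = random_box(multi_rater)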