import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from scipy.ndimage import gaussian_filter

from loss import FocalLoss, BinaryDiceLoss
from tools import visualization, calculate_metric, calculate_average_metric
from .adaclip import *
from .custom_clip import create_model_and_transforms


class AdaCLIP_Trainer(nn.Module):
    """Wraps a frozen CLIP backbone with AdaCLIP's learnable prompting modules and
    provides training, evaluation, and checkpointing utilities."""

    def __init__(
            self,
            # clip-related
            backbone, feat_list, input_dim, output_dim,
            # learning-related
            learning_rate, device, image_size,
            # model settings
            prompting_depth=3, prompting_length=2,
            prompting_branch='VL', prompting_type='SD',
            use_hsf=True, k_clusters=20,
    ):
        super(AdaCLIP_Trainer, self).__init__()

        self.device = device
        self.feat_list = feat_list
        self.image_size = image_size
        self.prompting_branch = prompting_branch
        self.prompting_type = prompting_type

        self.loss_focal = FocalLoss()
        self.loss_dice = BinaryDiceLoss()

        # Build the frozen CLIP backbone and wrap it with the AdaCLIP prompting modules.
        freeze_clip, _, self.preprocess = create_model_and_transforms(backbone, image_size,
                                                                      pretrained='openai')
        freeze_clip = freeze_clip.to(device)
        freeze_clip.eval()

        self.clip_model = AdaCLIP(freeze_clip=freeze_clip,
                                  text_channel=output_dim,
                                  visual_channel=input_dim,
                                  prompting_length=prompting_length,
                                  prompting_depth=prompting_depth,
                                  prompting_branch=prompting_branch,
                                  prompting_type=prompting_type,
                                  use_hsf=use_hsf,
                                  k_clusters=k_clusters,
                                  output_layers=feat_list,
                                  device=device,
                                  image_size=image_size).to(device)

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

        # Override the default CLIP preprocessing so resize/crop match image_size.
        self.preprocess.transforms[0] = transforms.Resize(size=(image_size, image_size),
                                                          interpolation=transforms.InterpolationMode.BICUBIC,
                                                          max_size=None)
        self.preprocess.transforms[1] = transforms.CenterCrop(size=(image_size, image_size))

        # Only the prompting-related modules are trainable; the CLIP backbone stays frozen.
        self.learnable_parameter_list = [
            'text_prompter',
            'visual_prompter',
            'patch_token_layer',
            'cls_token_layer',
            'dynamic_visual_prompt_generator',
            'dynamic_text_prompt_generator'
        ]

        self.params_to_update = []
        for name, param in self.clip_model.named_parameters():
            for update_name in self.learnable_parameter_list:
                if update_name in name:
                    self.params_to_update.append(param)
                    break

        # Build the optimizer over the learnable parameters only.
        self.optimizer = torch.optim.AdamW(self.params_to_update, lr=learning_rate, betas=(0.5, 0.999))
    def save(self, path):
        # Persist only the learnable (prompting-related) parameters.
        save_dict = {}
        for param, value in self.state_dict().items():
            for update_name in self.learnable_parameter_list:
                if update_name in param:
                    save_dict[param] = value
                    break
        torch.save(save_dict, path)

    def load(self, path):
        # strict=False because only the prompting-related parameters were saved.
        self.load_state_dict(torch.load(path, map_location=self.device), strict=False)
    def train_one_batch(self, items):
        image = items['img'].to(self.device)
        cls_name = items['cls_name']

        # pixel-level anomaly map and image-level anomaly score
        anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False)
        if not isinstance(anomaly_map, list):
            anomaly_map = [anomaly_map]

        # binarize the pixel-level ground truth
        gt = items['img_mask'].to(self.device)
        gt = gt.squeeze()
        gt[gt > 0.5] = 1
        gt[gt <= 0.5] = 0

        # binarize the image-level ground truth
        is_anomaly = items['anomaly'].to(self.device)
        is_anomaly[is_anomaly > 0.5] = 1
        is_anomaly[is_anomaly <= 0.5] = 0

        loss = 0

        # classification (image-level) loss
        classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1))
        loss += classification_loss

        # segmentation (pixel-level) loss, summed over every predicted anomaly map
        seg_loss = 0
        for am in anomaly_map:
            seg_loss += (self.loss_focal(am, gt) + self.loss_dice(am[:, 1, :, :], gt) +
                         self.loss_dice(am[:, 0, :, :], 1 - gt))
        loss += seg_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss
    def train_epoch(self, loader):
        self.clip_model.train()
        loss_list = []
        for items in loader:
            loss = self.train_one_batch(items)
            loss_list.append(loss.item())
        return np.mean(loss_list)
    def evaluation(self, dataloader, obj_list, save_fig, save_fig_dir=None):
        self.clip_model.eval()

        results = {
            'cls_names': [],
            'imgs_gts': [],
            'anomaly_scores': [],
            'imgs_masks': [],
            'anomaly_maps': [],
            'imgs': [],
            'names': [],
        }

        with torch.no_grad(), torch.cuda.amp.autocast():
            image_indx = 0
            for indx, items in enumerate(dataloader):
                # keep raw images and per-image names only when figures are requested
                if save_fig:
                    path = items['img_path']
                    for _path in path:
                        vis_image = cv2.resize(cv2.imread(_path), (self.image_size, self.image_size))
                        results['imgs'].append(vis_image)
                    cls_name = items['cls_name']
                    for _cls_name in cls_name:
                        image_indx += 1
                        results['names'].append('{:}-{:03d}'.format(_cls_name, image_indx))

                image = items['img'].to(self.device)
                cls_name = items['cls_name']
                results['cls_names'].extend(cls_name)

                gt_mask = items['img_mask']
                gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
                for _gt_mask in gt_mask:
                    results['imgs_masks'].append(_gt_mask.squeeze(0).numpy())  # pixel-level ground truth

                # pixel-level anomaly map and image-level anomaly score
                anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=True)
                anomaly_map = anomaly_map.cpu().numpy()
                anomaly_score = anomaly_score.cpu().numpy()

                for _anomaly_map, _anomaly_score in zip(anomaly_map, anomaly_score):
                    _anomaly_map = gaussian_filter(_anomaly_map, sigma=4)
                    results['anomaly_maps'].append(_anomaly_map)
                    results['anomaly_scores'].append(_anomaly_score)

                is_anomaly = np.array(items['anomaly'])
                for _is_anomaly in is_anomaly:
                    results['imgs_gts'].append(_is_anomaly)

        # visualization
        if save_fig:
            print('saving fig.....')
            visualization.plot_sample_cv2(
                results['names'],
                results['imgs'],
                {'AdaCLIP': results['anomaly_maps']},
                results['imgs_masks'],
                save_fig_dir
            )

        # per-object metrics plus their average
        metric_dict = dict()
        for obj in obj_list:
            metric_dict[obj] = calculate_metric(results, obj)
        metric_dict['Average'] = calculate_average_metric(metric_dict)

        return metric_dict
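

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# backbone name, feature layers, channel sizes, and the `build_dataloader`
# helper below are placeholders/assumptions; the loader is expected to yield
# dicts with the keys consumed above ('img', 'cls_name', 'img_mask',
# 'anomaly', and 'img_path' when save_fig=True).
#
#   trainer = AdaCLIP_Trainer(
#       backbone='ViT-L-14-336',          # example OpenCLIP model name
#       feat_list=[6, 12, 18, 24],        # example intermediate layers
#       input_dim=1024, output_dim=768,   # example visual/text channel sizes
#       learning_rate=1e-3, device='cuda', image_size=518,
#   )
#   train_loader = build_dataloader(split='train')   # hypothetical helper
#   test_loader = build_dataloader(split='test')
#   for epoch in range(num_epochs):
#       avg_loss = trainer.train_epoch(train_loader)
#   metrics = trainer.evaluation(test_loader, obj_list=['bottle'], save_fig=False)
#   trainer.save('weights/adaclip_prompts.pth')
# ---------------------------------------------------------------------------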