import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from scipy.ndimage import gaussian_filter

from loss import FocalLoss, BinaryDiceLoss
from tools import visualization, calculate_metric, calculate_average_metric
from .adaclip import *
from .custom_clip import create_model_and_transforms


class AdaCLIP_Trainer(nn.Module):
    def __init__(
            self,
            # clip-related
            backbone, feat_list, input_dim, output_dim,
            # learning-related
            learning_rate, device, image_size,
            # model settings
            prompting_depth=3,
            prompting_length=2,
            prompting_branch='VL',
            prompting_type='SD',
            use_hsf=True,
            k_clusters=20,
    ):
        super(AdaCLIP_Trainer, self).__init__()

        self.device = device
        self.feat_list = feat_list
        self.image_size = image_size
        self.prompting_branch = prompting_branch
        self.prompting_type = prompting_type

        self.loss_focal = FocalLoss()
        self.loss_dice = BinaryDiceLoss()

        # Frozen CLIP backbone: its weights stay fixed; only the prompting modules are trained.
        freeze_clip, _, self.preprocess = create_model_and_transforms(backbone, image_size, pretrained='openai')
        freeze_clip = freeze_clip.to(device)
        freeze_clip.eval()

        self.clip_model = AdaCLIP(freeze_clip=freeze_clip,
                                  text_channel=output_dim,
                                  visual_channel=input_dim,
                                  prompting_length=prompting_length,
                                  prompting_depth=prompting_depth,
                                  prompting_branch=prompting_branch,
                                  prompting_type=prompting_type,
                                  use_hsf=use_hsf,
                                  k_clusters=k_clusters,
                                  output_layers=feat_list,
                                  device=device,
                                  image_size=image_size).to(device)

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

        # Override the default CLIP preprocessing so it matches the requested image size.
        self.preprocess.transforms[0] = transforms.Resize(size=(image_size, image_size),
                                                          interpolation=transforms.InterpolationMode.BICUBIC,
                                                          max_size=None)
        self.preprocess.transforms[1] = transforms.CenterCrop(size=(image_size, image_size))

        # Only the prompting-related modules are trainable; the rest of CLIP stays frozen.
        self.learnable_paramter_list = [
            'text_prompter',
            'visual_prompter',
            'patch_token_layer',
            'cls_token_layer',
            'dynamic_visual_prompt_generator',
            'dynamic_text_prompt_generator'
        ]

        self.params_to_update = []
        for name, param in self.clip_model.named_parameters():
            for update_name in self.learnable_paramter_list:
                if update_name in name:
                    self.params_to_update.append(param)

        # Build the optimizer over the learnable prompting parameters only.
        self.optimizer = torch.optim.AdamW(self.params_to_update, lr=learning_rate, betas=(0.5, 0.999))

    def save(self, path):
        # Persist only the learnable (prompting) parameters, not the frozen CLIP weights.
        self.save_dict = {}
        for param, value in self.state_dict().items():
            for update_name in self.learnable_paramter_list:
                if update_name in param:
                    self.save_dict[param] = value
                    break
        torch.save(self.save_dict, path)

    def load(self, path):
        # strict=False because the checkpoint contains only the learnable parameters.
        self.load_state_dict(torch.load(path, map_location=self.device), strict=False)

    def train_one_batch(self, items):
        image = items['img'].to(self.device)
        cls_name = items['cls_name']

        # Pixel-level anomaly maps and image-level anomaly score.
        anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False)

        if not isinstance(anomaly_map, list):
            anomaly_map = [anomaly_map]

        # Binarize the pixel-level ground truth.
        gt = items['img_mask'].to(self.device)
        gt = gt.squeeze()
        gt[gt > 0.5] = 1
        gt[gt <= 0.5] = 0

        # Binarize the image-level ground truth.
        is_anomaly = items['anomaly'].to(self.device)
        is_anomaly[is_anomaly > 0.5] = 1
        is_anomaly[is_anomaly <= 0.5] = 0

        loss = 0

        # Image-level classification loss.
        classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1))
        loss += classification_loss

        # Pixel-level segmentation loss: focal loss on the map plus Dice losses on
        # the anomalous (channel 1) and normal (channel 0) predictions.
        seg_loss = 0
        for am in anomaly_map:
            seg_loss += (self.loss_focal(am, gt) +
                         self.loss_dice(am[:, 1, :, :], gt) +
                         self.loss_dice(am[:, 0, :, :], 1 - gt))
        loss += seg_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    def train_epoch(self, loader):
        self.clip_model.train()
        loss_list = []
        for items in loader:
            loss = self.train_one_batch(items)
            loss_list.append(loss.item())

        return np.mean(loss_list)

    @torch.no_grad()
    def evaluation(self, dataloader, obj_list, save_fig, save_fig_dir=None):
        self.clip_model.eval()

        results = {}
        results['cls_names'] = []
        results['imgs_gts'] = []
        results['anomaly_scores'] = []
        results['imgs_masks'] = []
        results['anomaly_maps'] = []
        results['imgs'] = []
        results['names'] = []

        with torch.cuda.amp.autocast():  # gradients are already disabled by @torch.no_grad()
            image_indx = 0
            for indx, items in enumerate(dataloader):
                if save_fig:
                    path = items['img_path']
                    for _path in path:
                        vis_image = cv2.resize(cv2.imread(_path), (self.image_size, self.image_size))
                        results['imgs'].append(vis_image)
                    cls_name = items['cls_name']
                    for _cls_name in cls_name:
                        image_indx += 1
                        results['names'].append('{:}-{:03d}'.format(_cls_name, image_indx))

                image = items['img'].to(self.device)
                cls_name = items['cls_name']
                results['cls_names'].extend(cls_name)

                # Binarize and store the pixel-level ground truth.
                gt_mask = items['img_mask']
                gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
                for _gt_mask in gt_mask:
                    results['imgs_masks'].append(_gt_mask.squeeze(0).numpy())

                # Pixel-level anomaly map and image-level anomaly score.
                anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=True)

                anomaly_map = anomaly_map.cpu().numpy()
                anomaly_score = anomaly_score.cpu().numpy()

                for _anomaly_map, _anomaly_score in zip(anomaly_map, anomaly_score):
                    # Smooth the anomaly map before computing metrics.
                    _anomaly_map = gaussian_filter(_anomaly_map, sigma=4)
                    results['anomaly_maps'].append(_anomaly_map)
                    results['anomaly_scores'].append(_anomaly_score)

                is_anomaly = np.array(items['anomaly'])
                for _is_anomaly in is_anomaly:
                    results['imgs_gts'].append(_is_anomaly)

        # Visualization
        if save_fig:
            print('saving fig.....')
            visualization.plot_sample_cv2(
                results['names'],
                results['imgs'],
                {'AdaCLIP': results['anomaly_maps']},
                results['imgs_masks'],
                save_fig_dir
            )

        # Per-object metrics plus their average.
        metric_dict = dict()
        for obj in obj_list:
            metric_dict[obj] = calculate_metric(results, obj)

        metric_dict['Average'] = calculate_average_metric(metric_dict)

        return metric_dict
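

# A minimal usage sketch (not part of the original module). The backbone name,
# feature layers, channel widths, image size, and the `train_loader` /
# `test_loader` / `obj_list` objects below are illustrative assumptions, not
# values prescribed by this file; substitute whatever your configuration uses.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Hypothetical settings: a ViT-L/14-336 CLIP backbone with features tapped
    # at four intermediate layers and the usual ViT-L channel widths.
    trainer = AdaCLIP_Trainer(
        backbone='ViT-L-14-336',
        feat_list=[6, 12, 18, 24],
        input_dim=1024,       # visual transformer width of ViT-L/14
        output_dim=768,       # CLIP text embedding dimension for ViT-L
        learning_rate=1e-3,
        device=device,
        image_size=518,
    )

    # `train_loader` must yield dicts with the keys consumed above:
    # 'img', 'cls_name', 'img_mask', 'anomaly' (and 'img_path' when save_fig=True).
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    # for epoch in range(5):
    #     avg_loss = trainer.train_epoch(train_loader)
    #     print(f'epoch {epoch}: loss={avg_loss:.4f}')
    # trainer.save('adaclip_prompts.pth')  # stores only the prompting parameters

    # Evaluation over a hypothetical test loader and object list:
    # metrics = trainer.evaluation(test_loader, obj_list=['bottle'], save_fig=False)
    # print(metrics['Average'])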