from typing import Optional, Union

import torch
from torch import nn
from transformers import ResNetConfig, ResNetModel, ResNetPreTrainedModel
from transformers.image_processing_utils import BaseImageProcessor
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention


def _generalized_gaussian_logit(x: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Return ``logit(F(x))`` for the zero-mean, unit-scale generalized-Gaussian CDF ``F``.

    ``F(x) = 1/2 + sign(x)/2 * P(1/beta, |x|**beta)``, where ``P`` is the regularized lower
    incomplete gamma function. ``beta == 1`` is the Laplace distribution, for which
    ``F(x) = 1/2 + sign(x)/2 * (1 - exp(-|x|))``.

    Evaluating ``F`` first and then ``log(F) - log1p(-F)`` loses all precision once ``P``
    rounds to 1. That happens for |x| > ~17 in float32 when ``beta == 1``, and it yields
    infinite logits and NaN gradients. Writing ``Q = 1 - P`` (the regularized *upper*
    incomplete gamma, evaluated directly rather than by subtraction), both signs of ``x``
    simplify to::

        logit(F(x)) = sign(x) * (log1p(P) - log(Q))

    This stays accurate far into the tails. ``Q`` is clamped to the smallest normal float,
    so fully saturated inputs produce large finite logits instead of ``inf``.

    Args:
        x: Raw scores. Centering and scaling are assumed to be already applied.
        beta: Generalized-Gaussian shape parameter; broadcast against ``x``.

    Returns:
        Logits with the same shape and dtype as ``x``.
    """
    # gammainc/gammaincc have no half-precision kernels; compute in at least float32.
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    z = x.to(compute_dtype)
    beta = beta.to(compute_dtype)

    a = 1.0 / beta
    t = z.abs() ** beta
    lower = torch.special.gammainc(a, t)  # P(a, t), in [0, 1]
    upper = torch.special.gammaincc(a, t).clamp_min(torch.finfo(compute_dtype).tiny)  # Q = 1 - P
    logits = torch.sign(z) * (torch.log1p(lower) - torch.log(upper))
    return logits.to(x.dtype)


class ResNetForZeroBitWatermarkDetection(ResNetPreTrainedModel):
    """ResNet backbone with a single-logit head for zero-bit watermark detection.

    The pooled ResNet features go through a small MLP that produces one score per image.
    That score is recalibrated through a generalized-Gaussian CDF, then converted back
    into a logit.
    """

    def __init__(self, config: ResNetConfig):
        super().__init__(config)
        self.resnet = ResNetModel(config)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.hidden_sizes[-1], 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )
        # Shape parameter of the generalized-Gaussian recalibration (1.0 = Laplace).
        # It is a persistent buffer: stored in the state dict, but not a trainable parameter.
        self.register_buffer('exp', torch.tensor([1.0]))

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        """Score images for the presence of a watermark.

        Args:
            pixel_values: Images passed straight to the ResNet backbone.
            output_hidden_states: Whether to also return the backbone's hidden states.
            return_dict: Return a ``ModelOutput`` instead of a plain tuple. Defaults to
                ``config.use_return_dict``.
            labels: Optional binary targets (0/1), one per image. When given, a
                binary-cross-entropy-with-logits loss is computed against the logits.

        Returns:
            Either an ``ImageClassifierOutputWithNoAttention`` with ``logits`` of shape
            ``(batch, 1)``, or the equivalent tuple ``([loss,] logits, [hidden_states])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.resnet(
            pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict
        )
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        # Centering and scaling of the generalized-Gaussian score are folded into the
        # last Linear layer, so the raw head output is fed to the CDF directly.
        scores = self.classifier(pooled_output)
        logits = _generalized_gaussian_logit(scores, self.exp)

        loss = None
        if labels is not None:
            loss = nn.functional.binary_cross_entropy_with_logits(
                logits.reshape(-1), labels.reshape(-1).to(logits.dtype)
            )

        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states
        )