import functools
from math import pi

import torch
import torchvision.transforms.functional as TF
from torchvision import transforms


# Define helper functions


def exists(val):
    """Check whether ``val`` is not ``None``."""
    return val is not None


def uniq(arr):
    """Return the unique elements of ``arr`` in first-seen order.

    The result is a dict keys view (not a list).
    """
    return {el: True for el in arr}.keys()


def default(val, d):
    """Return ``val`` if it is not ``None``; otherwise return the default ``d``."""
    return val if exists(val) else d


def max_neg_value(t):
    """Return the most negative finite value of ``t``'s dtype.

    ``torch.finfo`` only accepts floating-point dtypes.
    """
    return -torch.finfo(t.dtype).max


def cast_tuple(val, depth=1):
    """Coerce ``val`` to a tuple.

    Lists are converted to tuples and tuples are returned unchanged; any other
    value is repeated ``depth`` times, e.g. ``cast_tuple(3, 2) == (3, 3)``.
    """
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


def is_empty(t):
    """Return True if tensor ``t`` has zero elements."""
    return t.nelement() == 0


def masked_mean(t, mask, dim=1):
    """Compute the mean of ``t`` over dimension 1, ignoring masked-out positions.

    Args:
        t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim).
        mask (torch.Tensor): boolean tensor of shape (batch_size, seq_len);
            True marks positions that contribute to the mean.
        dim (int): accepted for API compatibility but currently ignored; the
            reduction is always over dimension 1.

    Returns:
        torch.Tensor: tensor of shape (batch_size, hidden_dim). A row whose mask
        is entirely False computes 0 / 0 and therefore yields NaN.
    """
    t = t.masked_fill(~mask[:, :, None], 0.0)
    return t.sum(dim=1) / mask.sum(dim=1)[..., None]


def set_requires_grad(model, value):
    """Set ``requires_grad`` on every parameter of ``model``.

    Args:
        model (torch.nn.Module): module whose parameters are modified in place.
        value (bool): new value for each parameter's ``requires_grad`` flag.
    """
    for param in model.parameters():
        param.requires_grad = value


def eval_decorator(fn):
    """Decorate ``fn(model, ...)`` so that it runs with ``model`` in eval mode.

    The model's previous training/eval mode is restored afterwards, even if
    ``fn`` raises.

    Args:
        fn (callable): function whose first positional argument is a
            ``torch.nn.Module``.

    Returns:
        callable: the wrapped function.
    """

    @functools.wraps(fn)
    def inner(model, *args, **kwargs):
        # Record the mode *before* switching: ``model.eval()`` returns the
        # module itself, not the previous mode.
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner


def log(t, eps=1e-20):
    """Compute the natural logarithm of ``t + eps``.

    Args:
        t (torch.Tensor): input tensor.
        eps (float): small value added to avoid taking the log of 0 (default=1e-20).

    Returns:
        torch.Tensor: ``log(t + eps)``, same shape as ``t``.
    """
    return torch.log(t + eps)


def gumbel_noise(t):
    """Generate standard Gumbel noise with the same shape, dtype and device as ``t``.

    Args:
        t (torch.Tensor): tensor whose shape/dtype/device the noise follows.

    Returns:
        torch.Tensor: ``-log(-log(u))`` with ``u ~ Uniform(0, 1)``.
    """
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=0.9, dim=-1):
    """Perturb temperature-scaled logits with Gumbel noise (Gumbel-max trick).

    This does not itself pick a sample: taking ``argmax`` of the returned
    tensor along the class dimension yields one.

    Args:
        t (torch.Tensor): logits, e.g. of shape (batch_size, num_classes).
        temperature (float): divisor for the logits; clamped to at least 1e-10
            to avoid division by zero (default=0.9).
        dim (int): currently unused; kept for API compatibility.

    Returns:
        torch.Tensor: ``t / temperature + gumbel_noise(t)``, same shape as ``t``.
    """
    return (t / max(temperature, 1e-10)) + gumbel_noise(t)


def top_k(logits, thres=0.5):
    """Keep the k largest logits along the last dimension and set the rest to -inf.

    Args:
        logits (torch.Tensor): input tensor, e.g. of shape (batch_size, num_classes).
        thres (float): fraction of logits to discard. The number kept is
            ``k = max(int((1 - thres) * num_logits), 1)`` (default=0.5).

    Returns:
        torch.Tensor: tensor with the same shape as ``logits``. The top-k
        values are kept in place; every other entry is ``-inf``.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(-1, ind, val)
    return probs


def gamma_func(mode="cosine", scale=0.15):
    """Return a schedule function ``r -> value`` for the selected ``mode``.

    Supported modes: "linear", "cosine", "square", "cubic", "scaled-cosine"
    (cosine multiplied by ``scale``). The cosine modes accept tensors as well
    as Python/numpy scalars (converted with ``torch.as_tensor``).

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    if mode == "linear":
        return lambda r: 1 - r
    elif mode == "cosine":
        return lambda r: torch.cos(torch.as_tensor(r) * pi / 2)
    elif mode == "square":
        return lambda r: 1 - r**2
    elif mode == "cubic":
        return lambda r: 1 - r**3
    elif mode == "scaled-cosine":
        return lambda r: scale * (torch.cos(torch.as_tensor(r) * pi / 2))
    else:
        raise NotImplementedError(f"gamma_func mode {mode!r} is not implemented")


class always:
    """Callable that ignores its arguments and always returns ``val``."""

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        return self.val


class DivideMax(torch.nn.Module):
    """Divide the input by its maximum along ``dim``.

    The maximum is detached, so no gradient flows through the normalizer.
    A zero maximum yields inf/NaN, as with plain tensor division.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim=self.dim, keepdim=True).detach()
        return x / maxes


def replace_outliers(image, percentile=0.0001):
    """Clip extreme intensities of ``image`` **in place**.

    Values outside the ``[percentile, 1 - percentile]`` quantile range are
    clipped into the range spanned by the remaining (in-range) values.

    Args:
        image (torch.Tensor): floating-point tensor (``torch.quantile``
            requires float/double input); modified in place.
        percentile (float): tail fraction treated as outliers on each side.

    Returns:
        torch.Tensor: the same ``image`` object after clipping. It is returned
        unchanged if no value lies inside the quantile range.
    """
    lower_bound = torch.quantile(image, percentile)
    upper_bound = torch.quantile(image, 1 - percentile)
    mask = (image <= upper_bound) & (image >= lower_bound)
    valid_pixels = image[mask]
    if valid_pixels.numel() == 0:
        # No in-range values: there is no range to clip into.
        return image
    # Tensor reductions instead of the builtin min()/max(), which would
    # iterate over the tensor element by element in Python.
    image[~mask] = torch.clip(image[~mask], valid_pixels.min(), valid_pixels.max())
    return image


# Per-(dataset, image_type) (mean, std) used when process_image(normalize=True).
_NORMALIZE_STATS = {
    ("HPA", "nucleus"): (0.0655, 0.0650),
    ("HPA", "protein"): (0.1732, 0.1208),
    ("OpenCell", "nucleus"): (0.0272, 0.0244),
    ("OpenCell", "protein"): (0.0486, 0.0671),
}


def process_image(image, dataset, image_type=None, normalize=False):
    """Convert an image to a single-channel (1, 1, 256, 256) tensor via a random crop.

    The first channel of ``TF.to_tensor(image)`` is taken and scaled by its
    maximum. Then a random 256x256 crop is applied (``transforms.RandomCrop``
    without padding; torchvision raises if either side is smaller than 256).

    Args:
        image: PIL image or numpy array accepted by ``TF.to_tensor``.
        dataset (str): dataset name ("HPA" or "OpenCell"); used only when
            ``normalize`` is True.
        image_type (str, optional): "nucleus" or "protein"; used only when
            ``normalize`` is True.
        normalize (bool): if True, additionally apply ``transforms.Normalize``
            with the statistics for ``(dataset, image_type)``. Defaults to
            False, which matches the historical behavior where normalization
            was disabled.

    Returns:
        torch.Tensor: floating-point tensor of shape (1, 1, 256, 256).

    Raises:
        ValueError: if ``normalize`` is True and there are no statistics for
            ``(dataset, image_type)``.
    """
    image = TF.to_tensor(image)[0].unsqueeze(0).unsqueeze(0)
    # to_tensor leaves non-uint8 inputs (e.g. 16-bit images) as integer
    # tensors; convert so the division below is valid.
    if not torch.is_floating_point(image):
        image = image.to(torch.get_default_dtype())
    max_val = image.max()
    if max_val != 0:
        # Out-of-place: to_tensor may share memory with the caller's numpy
        # array, and an all-zero image would otherwise become NaN.
        image = image / max_val

    t_forms = [transforms.RandomCrop(256)]
    if normalize:
        try:
            mean, std = _NORMALIZE_STATS[(dataset, image_type)]
        except KeyError:
            raise ValueError(
                f"No normalization statistics for dataset={dataset!r}, "
                f"image_type={image_type!r}"
            ) from None
        t_forms.append(transforms.Normalize(mean, std))
    return transforms.Compose(t_forms)(image)