Spaces:
Running
on
L40S
Running
on
L40S
from abc import ABC, abstractmethod | |
import torch | |
class BaseRewardLoss(ABC):
    """
    Base class for reward functions implementing a differentiable reward function for optimization.

    Subclasses supply the three abstract hooks (``get_image_features``,
    ``get_text_features`` and ``compute_loss``); calling an instance runs
    them as a pipeline: extract features -> L2-normalise -> compute loss.
    """

    def __init__(self, name: str, weighting: float):
        """Store the identifying name and the weighting of this reward.

        Args:
            name: Label for this reward.
            weighting: Scalar weight associated with this reward. It is only
                stored here; this class never applies it to the loss itself.
        """
        self.name = name
        self.weighting = weighting

    @staticmethod
    def freeze_parameters(params: torch.nn.ParameterList) -> None:
        """Disable gradient tracking for every parameter in ``params``.

        Declared as a static method so it works both as
        ``BaseRewardLoss.freeze_parameters(p)`` and ``self.freeze_parameters(p)``.

        Args:
            params: Iterable of parameters whose ``requires_grad`` flag is
                set to ``False`` in place.
        """
        for param in params:
            param.requires_grad = False

    @abstractmethod
    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return the feature tensor for ``image`` (implemented by subclasses)."""

    @abstractmethod
    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Return the feature tensor for ``prompt`` (implemented by subclasses)."""

    @abstractmethod
    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss for the given features (implemented by subclasses).

        When invoked through ``__call__`` both arguments have already been
        passed through ``process_features``.
        """

    def process_features(self, features: torch.Tensor) -> torch.Tensor:
        """Scale ``features`` to unit L2 norm along the last dimension.

        No epsilon guard is applied, so a vector with zero norm is divided
        by zero.
        """
        features_normed = features / features.norm(dim=-1, keepdim=True)
        return features_normed

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        """Compute the loss for an image/prompt pair.

        Args:
            image: Image tensor forwarded to ``get_image_features``.
            prompt: Text prompt forwarded to ``get_text_features``.

        Returns:
            Whatever ``compute_loss`` returns for the normalised features.
        """
        image_features = self.get_image_features(image)
        text_features = self.get_text_features(prompt)

        image_features_normed = self.process_features(image_features)
        text_features_normed = self.process_features(text_features)

        loss = self.compute_loss(image_features_normed, text_features_normed)
        return loss