import logging
from typing import Mapping

import torch
from torch import Tensor


def _setup_logger():
    """Configure the root logger for console output and return it.

    Sets the root logger's level to INFO and replaces whatever handlers it
    already had with a single ``StreamHandler`` using a
    ``[time level] message`` format.
    """
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)
    # Assigning the list (instead of addHandler) drops any pre-existing
    # handlers, so messages are not emitted twice.
    logger.handlers = [console_handler]

    return logger


logger = _setup_logger()


def move_to_cuda(sample):
    """Return a copy of ``sample`` with every contained tensor moved to CUDA.

    ``sample`` is walked recursively through dicts, lists, tuples
    (namedtuples keep their type) and other ``Mapping`` implementations.
    Anything that is not a tensor or one of those containers is returned
    unchanged.

    Args:
        sample: A sized object (it must support ``len()``), typically a
            batch dict of tensors.

    Returns:
        The converted structure, or an empty dict when ``sample`` is empty
        (whatever its original container type was).
    """
    # Kept for backward compatibility: any empty input collapses to ``{}``.
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            # dict subclasses are rebuilt as plain dicts.
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            moved = [_move_to_cuda(x) for x in maybe_tensor]
            if hasattr(maybe_tensor, "_fields"):
                # namedtuple: rebuild with the same type so field names
                # survive; its constructor takes positional fields.
                return type(maybe_tensor)(*moved)
            return tuple(moved)
        if isinstance(maybe_tensor, Mapping):
            # Non-dict mappings are rebuilt through their own constructor,
            # which is expected to accept a plain dict.
            return type(maybe_tensor)(
                {k: _move_to_cuda(v) for k, v in maybe_tensor.items()}
            )
        return maybe_tensor

    return _move_to_cuda(sample)


def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) -> Tensor:
    """Reduce per-token hidden states to one embedding per sequence.

    Positions where ``attention_mask`` is zero are zeroed out before pooling.
    The token axis is dimension 1 of ``last_hidden_states``.

    Args:
        last_hidden_states: Hidden states; the mask is broadcast over the
            trailing dimension.
        attention_mask: Mask with non-zero entries for tokens to keep.
        pool_type: ``"avg"`` for the mean over unmasked tokens, ``"cls"`` for
            the (masked) state at position 0.

    Returns:
        The pooled embeddings with the token dimension removed.

    Raises:
        ValueError: If ``pool_type`` is neither ``"avg"`` nor ``"cls"``.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        # A fully masked sequence has a token count of 0; clamp it to 1 so the
        # result is a zero vector instead of 0/0 = NaN. Rows with at least one
        # unmasked token are unaffected.
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        emb = last_hidden.sum(dim=1) / token_counts
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")

    return emb