Jinkin's picture
Upload utils.py
cf5524e
raw history blame
No virus
1.66 kB
import torch
import logging
from torch import Tensor
from typing import Mapping
def _setup_logger():
    """Configure and return the root logger.

    The root logger's level is set to INFO, and its handler list is replaced
    by a single ``logging.StreamHandler`` that prefixes each record with a
    timestamp and level name. Handlers previously attached to the root logger
    are discarded (not closed).
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    )

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    # Assigning the list instead of calling addHandler() means repeated calls
    # never stack duplicate handlers (and drop any that were there before).
    root.handlers = [handler]
    return root


# Module-wide logger; note the root-logger configuration happens at import time.
logger = _setup_logger()
def move_to_cuda(sample):
    """Recursively move every tensor inside *sample* to the current CUDA device.

    Args:
        sample: A tensor, or an arbitrarily nested structure of ``dict``,
            ``list``, ``tuple`` (including namedtuples) and other ``Mapping``
            types whose leaves may be tensors. Non-tensor leaves are returned
            unchanged.

    Returns:
        A structure of the same shape with every tensor replaced by its CUDA
        copy. An empty non-tensor container (``len(sample) == 0``) yields
        ``{}``, preserving the historical result for an empty batch.

    Note:
        Copies use ``non_blocking=True``; per PyTorch semantics the copy is
        only asynchronous with respect to the host when the source tensor is
        in pinned memory.
    """
    # Only short-circuit genuinely empty *containers*. Tensors are excluded so
    # a tensor with a zero-length first dimension is still moved (it used to
    # come back as {}), and objects without len() (ints, None, 0-d tensors)
    # fall through to the pass-through branch instead of raising TypeError.
    if (
        not torch.is_tensor(sample)
        and hasattr(sample, "__len__")
        and len(sample) == 0
    ):
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            # dict and dict subclasses are rebuilt as a plain dict.
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            moved = [_move_to_cuda(x) for x in maybe_tensor]
            # namedtuples take their fields positionally, not as one iterable;
            # rebuilding with the original type keeps attribute access working.
            if hasattr(maybe_tensor, "_fields"):
                return type(maybe_tensor)(*moved)
            return tuple(moved)
        if isinstance(maybe_tensor, Mapping):
            # Other mappings keep their type; this assumes the type's
            # constructor accepts a dict.
            return type(maybe_tensor)(
                {k: _move_to_cuda(v) for k, v in maybe_tensor.items()}
            )
        return maybe_tensor

    return _move_to_cuda(sample)
def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    """Reduce per-token hidden states to one embedding per sequence.

    Args:
        last_hidden_states: Token representations. The indexing below treats
            them as ``(batch, seq_len, hidden)`` — assumed layout, confirm
            with callers.
        attention_mask: Mask broadcast against ``last_hidden_states`` via
            ``[..., None]`` (so presumably ``(batch, seq_len)``); non-zero
            entries mark positions to keep.
        pool_type: ``"avg"`` for the mean over unmasked positions, or
            ``"cls"`` for the masked state at position 0.

    Returns:
        One vector per sequence (the sequence dimension is reduced away).
        With ``"avg"``, a sequence whose mask is entirely zero yields a zero
        vector instead of NaN.

    Raises:
        ValueError: If ``pool_type`` is neither ``"avg"`` nor ``"cls"``.
    """
    # Zero out masked positions so they contribute nothing to the sum below.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    if pool_type == "avg":
        # Clamp the token count so an all-padding row gives 0/1 = 0 rather
        # than 0/0 = NaN; rows with at least one token are unaffected.
        lengths = attention_mask.sum(dim=1).clamp(min=1)
        emb = last_hidden.sum(dim=1) / lengths[..., None]
    elif pool_type == "cls":
        # Taken from the masked states: if position 0 is masked out (e.g.
        # left padding), this is a zero vector.
        emb = last_hidden[:, 0]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")
    return emb