# File size: 19,999 Bytes
# Revision: 626eca0
import contextlib
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
import transformers as tr
from relik.common.log import get_console_logger, get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)
@dataclass
class GoldenRetrieverOutput(tr.file_utils.ModelOutput):
    """
    Result container returned by `GoldenRetriever.forward`.

    Every field defaults to `None`; `forward` fills in only the entries
    that were computed for the call.
    """

    # question/passage similarity scores
    logits: Optional[torch.FloatTensor] = None
    # loss value, filled in only when the caller asked for it and gave labels
    loss: Optional[torch.FloatTensor] = None
    # encoded questions, filled in only when encodings are requested
    question_encodings: Optional[torch.FloatTensor] = None
    # encoded passages, filled in only when encodings are requested
    passages_encodings: Optional[torch.FloatTensor] = None
class GoldenRetriever(torch.nn.Module):
    """
    Retriever built from a question encoder, a passage encoder and a document index.

    The passage encoder may be the very same module as the question encoder
    (shared weights); `passage_encoder_is_question_encoder` records that case.
    Tokenizers are loaded lazily, the first time the corresponding property
    is accessed.
    """

    def __init__(
        self,
        question_encoder: Union[str, tr.PreTrainedModel],
        loss_type: Optional[torch.nn.Module] = None,
        passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None,
        document_index: Optional[Union[str, BaseDocumentIndex]] = None,
        question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        device: Optional[Union[str, torch.device]] = None,
        precision: Optional[Union[str, int]] = None,
        index_precision: Optional[Union[str, int]] = 32,
        index_device: Optional[Union[str, torch.device]] = "cpu",
        *args,
        **kwargs,
    ):
        """
        Build the retriever.

        Args:
            question_encoder (`Union[str, tr.PreTrainedModel]`):
                The question encoder, or a name/path loaded with
                `GoldenRetrieverModel.from_pretrained`.
            loss_type (`Optional[torch.nn.Module]`):
                The loss module used by `forward` when a loss is requested.
            passage_encoder (`Optional[Union[str, tr.PreTrainedModel]]`):
                The passage encoder, or a name/path to load it from. When `None`,
                the question encoder is reused.
            document_index (`Optional[Union[str, BaseDocumentIndex]]`):
                The document index, or a name/path loaded with
                `BaseDocumentIndex.from_pretrained`. When `None`, a new
                `InMemoryDocumentIndex` is created.
            question_tokenizer (`Optional[Union[str, tr.PreTrainedTokenizer]]`):
                The question tokenizer, or a name/path to lazily load it from.
            passage_tokenizer (`Optional[Union[str, tr.PreTrainedTokenizer]]`):
                The passage tokenizer, or a name/path to lazily load it from.
            device (`Optional[Union[str, torch.device]]`):
                The device to move the module to. Defaults to CPU.
            precision (`Optional[Union[str, int]]`):
                The default encoder precision used by `index` and `retrieve`.
            index_precision (`Optional[Union[str, int]]`):
                The precision handed to the document index when it is created here.
            index_device (`Optional[Union[str, torch.device]]`):
                The device handed to the document index when it is created here.
            **kwargs:
                Forwarded to every `from_pretrained` call and to the
                `InMemoryDocumentIndex` constructor issued by this method.
        """
        super().__init__()

        self.passage_encoder_is_question_encoder = False
        # question encoder model
        if isinstance(question_encoder, str):
            question_encoder = GoldenRetrieverModel.from_pretrained(
                question_encoder, **kwargs
            )
        self.question_encoder = question_encoder
        if passage_encoder is None:
            # if no passage encoder is provided,
            # share the weights of the question encoder
            passage_encoder = question_encoder
            # keep track of the fact that the passage encoder is the same as the question encoder
            self.passage_encoder_is_question_encoder = True
        if isinstance(passage_encoder, str):
            passage_encoder = GoldenRetrieverModel.from_pretrained(
                passage_encoder, **kwargs
            )
        # passage encoder model
        self.passage_encoder = passage_encoder

        # loss function
        self.loss_type = loss_type

        # indexer stuff
        if document_index is None:
            # if no indexer is provided, create a new one
            document_index = InMemoryDocumentIndex(
                device=index_device, precision=index_precision, **kwargs
            )
        if isinstance(document_index, str):
            document_index = BaseDocumentIndex.from_pretrained(
                document_index, device=index_device, precision=index_precision, **kwargs
            )
        self.document_index = document_index

        # lazy load the tokenizer for inference
        self._question_tokenizer = question_tokenizer
        self._passage_tokenizer = passage_tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))

        # set the precision
        self.precision = precision

    def forward(
        self,
        questions: Optional[Dict[str, torch.Tensor]] = None,
        passages: Optional[Dict[str, torch.Tensor]] = None,
        labels: Optional[torch.Tensor] = None,
        question_encodings: Optional[torch.Tensor] = None,
        passages_encodings: Optional[torch.Tensor] = None,
        passages_per_question: Optional[List[int]] = None,
        return_loss: bool = False,
        return_encodings: bool = False,
        *args,
        **kwargs,
    ) -> GoldenRetrieverOutput:
        """
        Forward pass of the model.

        Args:
            questions (`Dict[str, torch.Tensor]`):
                The questions to encode. Ignored when `question_encodings` is given.
            passages (`Dict[str, torch.Tensor]`):
                The passages to encode. Ignored when `passages_encodings` is given.
            labels (`torch.Tensor`):
                The labels used to compute the loss.
            question_encodings (`torch.Tensor`):
                Precomputed encodings of the questions.
            passages_encodings (`torch.Tensor`):
                Precomputed encodings of the passages.
            passages_per_question (`List[int]`):
                The number of passages per question. When given, each question is
                scored only against its own group of passages; the groups must all
                have the same size because they are stacked into a single tensor.
            return_loss (`bool`):
                Whether to compute the loss. The loss is only computed when
                `labels` is provided as well.
            return_encodings (`bool`):
                Whether to return the encodings.

        Returns:
            `GoldenRetrieverOutput`: The logits and, when requested, the loss
            and the encodings.

        Raises:
            ValueError: If neither the raw inputs nor the encodings are given for
                questions or passages, or if a loss is requested while
                `loss_type` is `None`.
        """
        if questions is None and question_encodings is None:
            raise ValueError(
                "Either `questions` or `question_encodings` must be provided"
            )
        if passages is None and passages_encodings is None:
            raise ValueError(
                "Either `passages` or `passages_encodings` must be provided"
            )

        if question_encodings is None:
            question_encodings = self.question_encoder(**questions).pooler_output
        if passages_encodings is None:
            passages_encodings = self.passage_encoder(**passages).pooler_output

        # with BCE the similarity is a cosine, so both sides are L2-normalized
        use_cosine = isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss)

        if passages_per_question is not None:
            # group the flat passage encodings by question:
            # (questions, passages per question, embedding)
            grouped_passages = torch.stack(
                torch.split(passages_encodings, passages_per_question)
            )
            if use_cosine:
                # normalize every passage vector along the embedding axis (the
                # last one here); this has to happen before the transpose,
                # otherwise the norm would be taken across passages
                grouped_passages = F.normalize(grouped_passages, p=2, dim=-1)
                question_encodings = F.normalize(question_encodings, p=2, dim=1)
            # multiply each question encoding with its own group of passages
            logits = torch.bmm(
                question_encodings.unsqueeze(1), grouped_passages.transpose(1, 2)
            ).view(question_encodings.shape[0], -1)
        else:
            if use_cosine:
                # normalize the encodings for cosine similarity
                question_encodings = F.normalize(question_encodings, p=2, dim=1)
                passages_encodings = F.normalize(passages_encodings, p=2, dim=1)
            logits = torch.matmul(question_encodings, passages_encodings.T)

        output = dict(logits=logits)

        if return_loss and labels is not None:
            if self.loss_type is None:
                raise ValueError(
                    "If `return_loss` is set to `True`, `loss_type` must be provided"
                )
            if isinstance(self.loss_type, torch.nn.NLLLoss):
                # NLLLoss wants class indices and log-probabilities
                labels = labels.argmax(dim=1)
                logits = F.log_softmax(logits, dim=1)
                if len(question_encodings.size()) > 1:
                    logits = logits.view(question_encodings.size(0), -1)

            output["loss"] = self.loss_type(logits, labels)

        if return_encodings:
            output["question_encodings"] = question_encodings
            output["passages_encodings"] = passages_encodings

        return GoldenRetrieverOutput(**output)

    @torch.no_grad()
    @torch.inference_mode()
    def index(
        self,
        batch_size: int = 32,
        num_workers: int = 4,
        max_length: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        force_reindex: bool = False,
        compute_on_cpu: bool = False,
        precision: Optional[Union[str, int]] = None,
    ):
        """
        Index the passages for later retrieval.

        All the work is delegated to `self.document_index.index`, with this
        retriever passed as the encoder.

        Args:
            batch_size (`int`):
                The batch size to use for the indexing.
            num_workers (`int`):
                The number of workers to use for the indexing.
            max_length (`Optional[int]`):
                The maximum length of the passages.
            collate_fn (`Callable`):
                The collate function to use for the indexing.
            force_reindex (`bool`):
                Whether to force reindexing even if the passages are already indexed.
            compute_on_cpu (`bool`):
                Whether to move the index to the CPU after the indexing.
            precision (`Optional[Union[str, int]]`):
                The precision to use for the model. Falls back to the precision
                given at construction time.

        Returns:
            Whatever `self.document_index.index` returns.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The retriever must be initialized with an indexer to index "
                "the passages within the retriever."
            )
        return self.document_index.index(
            retriever=self,
            batch_size=batch_size,
            num_workers=num_workers,
            max_length=max_length,
            collate_fn=collate_fn,
            encoder_precision=precision or self.precision,
            compute_on_cpu=compute_on_cpu,
            force_reindex=force_reindex,
        )

    @torch.no_grad()
    @torch.inference_mode()
    def retrieve(
        self,
        text: Optional[Union[str, List[str]]] = None,
        text_pair: Optional[Union[str, List[str]]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        k: Optional[int] = None,
        max_length: Optional[int] = None,
        precision: Optional[Union[str, int]] = None,
    ) -> List[List[RetrievedSample]]:
        """
        Retrieve the passages for the questions.

        Args:
            text (`Optional[Union[str, List[str]]]`):
                The questions to retrieve the passages for.
            text_pair (`Optional[Union[str, List[str]]]`):
                Optional second sequence(s), handed to the tokenizer as
                `text_pair` together with `text`.
            input_ids (`torch.Tensor`):
                The input ids of the questions. Only used when `text` is `None`.
            attention_mask (`torch.Tensor`):
                The attention mask of the questions. Only used when `text` is `None`.
            token_type_ids (`torch.Tensor`):
                The token type ids of the questions. Only used when `text` is `None`.
            k (`int`):
                The number of top passages to retrieve.
            max_length (`Optional[int]`):
                The maximum length of the questions. Defaults to the tokenizer's
                `model_max_length`.
            precision (`Optional[Union[str, int]]`):
                The precision to use for the model. Falls back to the precision
                given at construction time.

        Returns:
            `List[List[RetrievedSample]]`: The retrieved passages and their indices.

        Raises:
            ValueError: If the retriever has no document index, or if neither
                `text` nor `input_ids` is provided.
        """
        if self.document_index is None:
            raise ValueError(
                "The indexer must be indexed before it can be used within the retriever."
            )
        if text is None and input_ids is None:
            raise ValueError(
                "Either `text` or `input_ids` must be provided to retrieve the passages."
            )

        if text is not None:
            # the tokenizer is always fed batches
            if isinstance(text, str):
                text = [text]
            if isinstance(text_pair, str):
                text_pair = [text_pair]
            tokenizer = self.question_tokenizer
            model_inputs = ModelInputs(
                tokenizer(
                    text,
                    text_pair=text_pair,
                    padding=True,
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_length or tokenizer.model_max_length,
                )
            )
        else:
            model_inputs = ModelInputs(dict(input_ids=input_ids))
            if attention_mask is not None:
                model_inputs["attention_mask"] = attention_mask
            if token_type_ids is not None:
                model_inputs["token_type_ids"] = token_type_ids

        device = self.device
        model_inputs.to(device)

        # same fallback as `index`: an explicit argument wins over the
        # precision configured on the retriever
        precision = precision or self.precision

        # autocast only accepts a bare device type such as 'cpu' or 'cuda',
        # without the device ordinal
        device_type_for_autocast = device.type
        # autocast doesn't work with CPU and stuff different from bfloat16,
        # so it is skipped altogether on CPU
        if device_type_for_autocast == "cpu":
            autocast_pssg_mngr = contextlib.nullcontext()
        else:
            autocast_pssg_mngr = torch.autocast(
                device_type=device_type_for_autocast,
                dtype=PRECISION_MAP[precision],
            )

        with autocast_pssg_mngr:
            question_encodings = self.question_encoder(**model_inputs).pooler_output

        # TODO: fix if encoder and index are on different device
        return self.document_index.search(question_encodings, k)

    def get_index_from_passage(self, passage: str) -> int:
        """
        Get the index of the passage.

        Args:
            passage (`str`):
                The passage to get the index for.

        Returns:
            `int`: The index of the passage.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_index_from_passage(passage)

    def get_passage_from_index(self, index: int) -> str:
        """
        Get the passage from the index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `str`: The passage.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_passage_from_index(index)

    def get_vector_from_index(self, index: int) -> torch.Tensor:
        """
        Get the passage vector from the index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `torch.Tensor`: The passage vector.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_embeddings_from_index(index)

    def get_vector_from_passage(self, passage: str) -> torch.Tensor:
        """
        Get the passage vector from the passage.

        Args:
            passage (`str`):
                The passage.

        Returns:
            `torch.Tensor`: The passage vector.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_embeddings_from_passage(passage)

    @property
    def passage_embeddings(self) -> torch.Tensor:
        """
        The passage embeddings.
        """
        # NOTE(review): `_passage_embeddings` is not assigned anywhere in this
        # class; confirm who is expected to set it before relying on this.
        return self._passage_embeddings

    @property
    def passage_index(self) -> Labels:
        """
        The passage index.
        """
        # NOTE(review): `_passage_index` is not assigned anywhere in this
        # class; confirm who is expected to set it before relying on this.
        return self._passage_index

    @property
    def device(self) -> torch.device:
        """
        The device of the model, read from its first parameter.
        """
        return next(self.parameters()).device

    def _encoders_share_checkpoint(self) -> bool:
        """Tell whether both encoders report the same `config.name_or_path`."""
        return (
            self.question_encoder.config.name_or_path
            == self.passage_encoder.config.name_or_path
        )

    @property
    def question_tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The question tokenizer.

        It is loaded on first access: from the name/path given at construction
        time if there was one, from the question encoder's
        `config.name_or_path` otherwise.
        """
        if self._question_tokenizer and isinstance(self._question_tokenizer, str):
            # a name/path was given at construction time: resolve it now
            self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
                self._question_tokenizer
            )
        if self._question_tokenizer:
            return self._question_tokenizer

        self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
            self.question_encoder.config.name_or_path
        )
        # reuse the tokenizer for the passages only when both encoders come
        # from the same checkpoint, and never clobber a passage tokenizer
        # that is already set
        if not self._passage_tokenizer and self._encoders_share_checkpoint():
            self._passage_tokenizer = self._question_tokenizer
        return self._question_tokenizer

    @property
    def passage_tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The passage tokenizer.

        It is loaded on first access: from the name/path given at construction
        time if there was one; otherwise the question tokenizer is reused when
        both encoders report the same `config.name_or_path`, and a tokenizer is
        loaded from the passage encoder's `config.name_or_path` when they do not.
        """
        if self._passage_tokenizer and isinstance(self._passage_tokenizer, str):
            # a name/path was given at construction time: resolve it now
            self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
                self._passage_tokenizer
            )
        if self._passage_tokenizer:
            return self._passage_tokenizer

        if self._encoders_share_checkpoint():
            # same checkpoint on both sides: a single tokenizer serves both
            self._passage_tokenizer = self.question_tokenizer
        else:
            self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
                self.passage_encoder.config.name_or_path
            )
        return self._passage_tokenizer

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        question_encoder_name: Optional[str] = None,
        passage_encoder_name: Optional[str] = None,
        document_index_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the retriever to a directory.

        The question encoder and its tokenizer are always saved. The passage
        encoder and its tokenizer are saved only when the passage encoder is
        not the question encoder itself. The document index is saved when
        there is one.

        Args:
            output_dir (`str`):
                The directory to save the retriever to.
            question_encoder_name (`Optional[str]`):
                The name of the question encoder. Defaults to `question_encoder`.
            passage_encoder_name (`Optional[str]`):
                The name of the passage encoder. Defaults to `passage_encoder`.
            document_index_name (`Optional[str]`):
                The name of the document index. Defaults to `document_index`.
            push_to_hub (`bool`):
                Whether to push the model to the hub.
            **kwargs:
                Forwarded to every `save_pretrained` call issued by this method.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving retriever to {output_dir}")

        question_encoder_name = question_encoder_name or "question_encoder"
        passage_encoder_name = passage_encoder_name or "passage_encoder"
        document_index_name = document_index_name or "document_index"

        logger.info(
            f"Saving question encoder state to {output_dir / question_encoder_name}"
        )
        self.question_encoder.register_for_auto_class()
        self.question_encoder.save_pretrained(
            output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs
        )
        self.question_tokenizer.save_pretrained(
            output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs
        )
        if not self.passage_encoder_is_question_encoder:
            logger.info(
                f"Saving passage encoder state to {output_dir / passage_encoder_name}"
            )
            self.passage_encoder.register_for_auto_class()
            self.passage_encoder.save_pretrained(
                output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs
            )
            self.passage_tokenizer.save_pretrained(
                output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs
            )

        if self.document_index is not None:
            # save the indexer
            self.document_index.save_pretrained(
                output_dir / document_index_name, push_to_hub=push_to_hub, **kwargs
            )

        logger.info("Saving retriever to disk done.")
|