import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator
from relik.retriever.pytorch_modules import GoldenRetriever

logger = get_logger(__name__, level=logging.INFO)

VERSION = {}  # type: ignore
with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="Golden Retriever",
    version=VERSION["VERSION"],
    description="Golden Retriever REST API",
)


@serve.deployment(
    ray_actor_options={
        "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class GoldenRetrieverServer:
    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        top_k: int = 100,
        device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.device = device
        self.index_device = index_device or device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing GoldenRetrieverServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"DEVICE: {self.device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.retriever = GoldenRetriever(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            device=self.device,
            index_device=self.index_device,
            index_precision=self.index_precision,
        )
        self.retriever.eval()

        if self.split_on_spaces:
            logger.info("Using WhitespaceTokenizer")
            self.tokenizer = WhitespaceTokenizer()
            # logger.info("Using RegexTokenizer")
            # self.tokenizer = RegexTokenizer()
        else:
            logger.info("Using SpacyTokenizer")
            self.tokenizer = SpacyTokenizer(language="en")

        self.window_manager = WindowManager(tokenizer=self.tokenizer)
    # @serve.batch()
    async def handle_batch(
        self, documents: List[str], document_topics: List[str]
    ) -> List:
        """Run retrieval for a batch of documents, optionally paired with topics."""
        return self.retriever.retrieve(
            documents,
            text_pair=document_topics,
            k=self.top_k,
            precision=self.precision,
        )

    @app.post("/api/retrieve")
    async def retrieve_endpoint(
        self,
        documents: Union[str, List[str]],
        document_topics: Optional[Union[str, List[str]]] = None,
    ):
        """Retrieve the top-k passages for one or more documents."""
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            if document_topics is not None:
                if isinstance(document_topics, str):
                    document_topics = [document_topics]
                assert len(documents) == len(document_topics)
            # get predictions
            return await self.handle_batch(documents, document_topics)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")

    @app.post("/api/gerbil")
    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        """Split documents into windows, retrieve passages per window, and
        attach the retrieved candidates to each window."""
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]

            # output list
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager(
                    self.tokenizer,
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]

            # get text and topic from document windows and create new list
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]

            # batch generator
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                batch_predictions = await self.handle_batch(text, text_pair)
                windows_passages.extend(
                    [
                        [p.label for p in predictions]
                        for predictions in batch_predictions
                    ]
                )

            # add passages to document windows
            for window, passages in zip(document_windows, windows_passages):
                # clean up passages (remove everything after first <def> tag if present)
                passages = [c.split(" <def>", 1)[0] for c in passages]
                window.window_candidates = passages

            # return document windows
            return document_windows

        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")


server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
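
# A minimal launch sketch (not part of the module). It assumes the
# `relik[serve]` extras are installed and that configuration is supplied via
# the environment variables read by ServerParameterManager and
# RayParameterManager. The import path below is an assumption about where this
# file lives; point `module:variable` at the actual location of the bound
# deployment `server`:
#
#   serve run relik.inference.serve.backend.retriever:server
#
# `serve run` starts a local Ray instance if one is not already running and
# deploys the bound application, exposing the FastAPI routes defined above
# (e.g. POST /api/retrieve and POST /api/gerbil).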