""" Main file to execute the TRS Pipeline. """ import sys from augmentation import prompt_generation as pg from information_retrieval import info_retrieval as ir from text_generation.models import ( Llama3, Mistral, Gemma2, Llama3Point1, Llama3Instruct, MistralInstruct, Llama3Point1Instruct, Phi3SmallInstruct, GPT4, Gemini, Claude3Point5Sonnet, ) from text_generation import text_generation as tg import logging logger = logging.getLogger(__name__) logging.basicConfig(encoding='utf-8', level=logging.DEBUG) from src.text_generation.mapper import MODEL_MAPPER from src.post_processing.post_process import post_process_output TEST_DIR = "../tests/" MODELS = { 'GPT-4': GPT4, 'Llama3': Llama3, 'Mistral': Mistral, 'Gemma2': Gemma2, 'Llama3.1': Llama3Point1, 'Llama3-Instruct': Llama3Instruct, 'Mistral-Instruct': MistralInstruct, 'Llama3.1-Instruct': Llama3Point1Instruct, 'Phi3-Instruct': Phi3SmallInstruct, "Gemini-1.0-pro": Gemini, "Claude3.5-sonnet": Claude3Point5Sonnet, } def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params): """ Executes the entire RAG pipeline, provided the query and model class name. Args: - query: str - model_name: string, one of the following: Llama3, Mistral, Gemma2, Llama3Point1 - test: whether the pipeline is running a test - params: - limit (number of results to be retained) - reranking (binary, whether to rerank results using ColBERT or not) - sustainability """ try: model_id = MODEL_MAPPER[model_name] except KeyError: logger.error(f"Model {model_name} not found in the model mapper.") model_id = MODEL_MAPPER['Gemini-1.0-pro'] context_params = { 'limit': 5, 'reranking': 0, 'sustainability': 0, } if 'limit' in params: context_params['limit'] = params['limit'] if 'reranking' in params: context_params['reranking'] = params['reranking'] if 'sustainability' in params: context_params['sustainability'] = params['sustainability'] logger.info("Retrieving context..") try: context = ir.get_context(starting_point=starting_point, query=query, **context_params) if test: retrieved_cities = ir.get_cities(context) else: retrieved_cities = None except Exception as e: exc_type, exc_obj, exc_tb = sys.exc_info() logger.error(f"Error at line {exc_tb.tb_lineno} while trying to get context: {e}") return None logger.info("Retrieved context, augmenting prompt..") try: prompt = pg.augment_prompt( query=query, starting_point=starting_point, context=context, params=context_params ) except Exception as e: exc_type, exc_obj, exc_tb = sys.exc_info() logger.error(f"Error at line {exc_tb.tb_lineno} while trying to augment prompt: {e}") return None # return prompt logger.info(f"Augmented prompt, initializing {model_name} and generating response..") try: response = tg.generate_response(model_id, prompt, **params) except Exception as e: exc_type, exc_obj, exc_tb = sys.exc_info() logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}") return None try: model_params = {"max_tokens": params["max_tokens"], "temperature": params["temperature"]} post_processed_response = post_process_output( model_id=model_id, user_query=query, starting_point=starting_point, context=context, response=response, **model_params) except Exception as e: exc_type, exc_obj, exc_tb = sys.exc_info() logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}") return None if test: return retrieved_cities, prompt[1]['content'], post_processed_response else: return post_processed_response if __name__ == "__main__": # sample_query = "I'm planning a trip in the summer 
and I love art, history, and visiting museums. Can you # suggest " \ "some " \ "European cities? " sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \ "cities. " model_name = "GPT-4" pipeline_response = pipeline( query=sample_query, model_name=model_name, sustainability=1 ) print(pipeline_response)
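    # A minimal sketch of test mode, assuming the same retrieval and generation
    # services as the call above. With test=1 the pipeline also returns the
    # retrieved cities and the augmented user prompt, which is useful for
    # inspecting retrieval and augmentation. The starting point and the
    # limit/reranking/max_tokens/temperature values here are illustrative
    # assumptions, not project defaults.
    test_result = pipeline(
        starting_point="Munich",  # placeholder starting city
        query=sample_query,
        model_name=model_name,
        test=1,
        limit=3,          # retain only the top 3 retrieved results
        reranking=1,      # rerank retrieved results with ColBERT
        max_tokens=512,   # forwarded to generation and post-processing
        temperature=0.7,
    )
    if test_result is not None:
        retrieved_cities, augmented_prompt, final_response = test_result
        print(retrieved_cities)
        print(augmented_prompt)
        print(final_response)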