import os
import sys

from langchain_anthropic import ChatAnthropic
from langchain_fireworks import ChatFireworks
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import ChatOpenAI

# Make modules in the current working directory importable so the
# project-local ``KEYS`` module and ``research_assistant`` package resolve
# (presumably the process is launched from the repository root — confirm).
sys.path.append(os.getcwd())

import KEYS  # noqa: E402  (must follow the sys.path tweak above)
from research_assistant.app_logging import app_logger  # noqa: E402

# Maximum token counts keyed by *bare* model id, i.e. without any
# "accounts/<org>/models/" style path prefix. These appear to be
# context-window sizes. The "gpt-3.5" key only matches a name that is
# literally "gpt-3.5" (e.g. not "gpt-3.5-turbo").
_MAX_TOKENS_MAP = {
    "gpt-3.5": 16000,
    "gpt-4": 8000,
    "gpt-4o-mini": 8000,
    "llama-v3p2-1b-instruct": 128000,
    "llama-v3p2-3b-instruct": 128000,
    "llama-v3p1-8b-instruct": 128000,
    "llama-v3p1-70b-instruct": 128000,
    "llama-v3p1-405b-instruct": 128000,
    "mixtral-8x22b-instruct": 64000,
    "mixtral-8x7b-instruct": 32000,
    "mixtral-8x7b-instruct-hf": 32000,
    "qwen2p5-72b-instruct": 32000,
    "gemma2-9b-it": 8000,
    "llama-v3-8b-instruct": 8000,
    "llama-v3-70b-instruct": 8000,
    "llama-v3-70b-instruct-hf": 8000,
}

# Used when neither the full model name nor its last path component is
# listed in ``_MAX_TOKENS_MAP``.
_DEFAULT_MAX_TOKENS = 128000


def _lookup_max_tokens(model_name: str) -> int:
    """Return the token limit to store for *model_name*.

    The full name is tried first. If it is not listed, the last
    "/"-separated component is tried. This lets path-style names such as
    "accounts/fireworks/models/mixtral-8x7b-instruct" resolve to the
    "mixtral-8x7b-instruct" entry. Otherwise ``_DEFAULT_MAX_TOKENS`` is
    returned.
    """
    if model_name in _MAX_TOKENS_MAP:
        return _MAX_TOKENS_MAP[model_name]
    base_name = model_name.rsplit("/", 1)[-1]
    return _MAX_TOKENS_MAP.get(base_name, _DEFAULT_MAX_TOKENS)


def set_api_key(env_var: str, api_key: str):
    """Store *api_key* in the process environment under *env_var*."""
    os.environ[env_var] = api_key


class Agent:
    """Wrap a LangChain chat model selected by model name.

    The provider is chosen by checking, in order, whether "gpt", "claude",
    "gemini" or "fireworks" occurs as a substring of the model name. The
    first match whose API key is present in ``KEYS`` wins.
    """

    def __init__(self, model_name: str):
        """Instantiate the chat model for *model_name*.

        Sets the provider's API-key environment variable from ``KEYS`` and
        creates the chat model with ``temperature=0.5``.

        Args:
            model_name: Name passed through as the chat class's ``model``.

        Raises:
            ValueError: If no provider keyword matches *model_name*, or every
                matching provider lacks its key attribute in ``KEYS``.
        """
        # (keyword in model name, chat class, env var to set, KEYS attribute)
        providers = (
            ("gpt", ChatOpenAI, "OPENAI_API_KEY", "OPENAI"),
            ("claude", ChatAnthropic, "ANTHROPIC_API_KEY", "ANTHROPIC"),
            # NOTE(review): ChatVertexAI presumably authenticates through
            # Google Cloud credentials rather than GOOGLE_API_KEY — confirm
            # this env var is actually consulted.
            ("gemini", ChatVertexAI, "GOOGLE_API_KEY", "VERTEX_AI"),
            ("fireworks", ChatFireworks, "FIREWORKS_API_KEY", "FIREWORKS_AI"),
        )

        # KEYS attributes of providers that matched by name but are absent,
        # reported in the error so a missing key is not mistaken for an
        # unsupported model.
        unconfigured = []
        for keyword, model_class, env_var, key_attr in providers:
            if keyword not in model_name:
                continue
            api_key = getattr(KEYS, key_attr, None)
            if api_key is None:
                unconfigured.append(key_attr)
                continue
            set_api_key(env_var, api_key)
            model = model_class(model=model_name, temperature=0.5)
            break
        else:
            message = f"Model {model_name} not supported"
            if unconfigured:
                message += f" (no API key in KEYS for: {', '.join(unconfigured)})"
            raise ValueError(message)

        app_logger.info(f"Model {model_name} is initialized successfully")
        self.model = model
        self.max_tokens = _lookup_max_tokens(model_name)

    def get_model(self):
        """Return the chat model instance created in ``__init__``."""
        return self.model