""" This module contains functions to interact with the models. """

import json
import os
from typing import List, Optional

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json

GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Fail fast with a clear message. Otherwise os.environ.get() returns None and
# the secret path below silently becomes "projects/None/secrets/None/...",
# which only surfaces later as an opaque Secret Manager API error.
_missing_env_vars = [
    var_name for var_name, value in (
        ("GOOGLE_CLOUD_PROJECT", GOOGLE_CLOUD_PROJECT),
        ("MODELS_SECRET", MODELS_SECRET),
    ) if not value
]
if _missing_env_vars:
    raise RuntimeError("Missing required environment variable(s): "
                       f"{', '.join(_missing_env_vars)}")

# get_credentials_json() comes from the local `credentials` module; since it is
# passed to from_service_account_info, it presumably returns service-account
# info (a dict) — confirm in credentials.py.
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))

# The latest version of the secret is expected to hold a JSON object mapping
# each model name to its config, whose keys must match Model.__init__'s
# keyword parameters (provider / apiKey / apiBase); see supported_models below.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models_json = json.loads(decoded_secret)


class Model:
    """Connection settings for one model listed in the models secret.

    Args:
        name: Model name; also the key of this entry in the secret JSON.
        provider: Optional litellm provider prefix. When set, requests use
            "<provider>/<name>" as the litellm model string.
        apiKey: Optional API key, forwarded to litellm as `api_key`.
        apiBase: Optional API base, forwarded to litellm as `api_base`.
    """

    def __init__(
            self,
            name: str,
            provider: Optional[str] = None,
            # The JSON keys are in camelCase. To unpack these keys into
            # Model attributes, we need to use the same camelCase names.
            apiKey: Optional[str] = None,  # pylint: disable=invalid-name
            apiBase: Optional[str] = None):  # pylint: disable=invalid-name
        self.name = name
        self.provider = provider
        self.api_key = apiKey
        self.api_base = apiBase


# One Model per entry in the secret. An entry with a config key that
# Model.__init__ does not accept raises TypeError at import time.
supported_models: List[Model] = [
    Model(name=model_name, **model_config)
    for model_name, model_config in supported_models_json.items()
]


def completion(model: Model,
               messages: List,
               max_tokens: Optional[int] = None) -> str:
    """Send a chat request to `model` through litellm and return the reply.

    Args:
        model: The model to query. If `model.provider` is set, the litellm
            model string is "<provider>/<name>"; otherwise it is "<name>".
        messages: Chat messages forwarded unchanged to litellm.completion,
            e.g. [{"content": "Hello.", "role": "user"}].
        max_tokens: Optional cap on generated tokens, forwarded to litellm.

    Returns:
        The `message.content` of the first choice in the litellm response.
    """
    model_id = (f"{model.provider}/{model.name}"
                if model.provider else model.name)
    response = litellm.completion(model=model_id,
                                  api_key=model.api_key,
                                  api_base=model.api_base,
                                  messages=messages,
                                  max_tokens=max_tokens)
    return response.choices[0].message.content


def check_models(models: List[Model]) -> None:
    """Verify that each model answers a minimal request.

    Sends a short "Hello." prompt (limited to 5 tokens) to every model in
    order and prints progress to stdout.

    Args:
        models: Models to check, in order.

    Raises:
        RuntimeError: If a request fails. It wraps and chains the original
            exception. Models after the failing one are not checked.
    """
    for model in models:
        print(f"Checking model {model.name}...")
        try:
            completion(model=model,
                       messages=[{
                           "content": "Hello.",
                           "role": "user"
                       }],
                       max_tokens=5)
        # Any exception raised by the request means the model cannot be used
        # right now, so every exception type is deliberately caught here.
        except Exception as e:  # pylint: disable=broad-except
            raise RuntimeError(
                f"Model {model.name} is not available: {e}") from e
        else:
            # Kept outside the try body so a failure here is never
            # misreported as the model being unavailable.
            print(f"Model {model.name} is available.")