File size: 2,567 Bytes
2a0aa5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
This module contains functions to interact with the models.
"""

import json
import os
from typing import List
from typing import Optional

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json

# Configuration is read from the environment. os.environ.get returns None for
# an unset variable, and nothing here validates the values before they are
# used to build the secret path below.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# NOTE: the statements below run at import time, so importing this module
# builds a Secret Manager client and requests the secret (presumably a network
# call to the Secret Manager service).
#
# The client authenticates with service-account credentials built from the
# info returned by get_credentials_json().
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))
# Request the "latest" version of the secret named by MODELS_SECRET in the
# GOOGLE_CLOUD_PROJECT project.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
# Decode the secret payload as UTF-8 text.
decoded_secret = models_secret.payload.data.decode("UTF-8")

# Parsed JSON content of the secret. Further down it is iterated with .items()
# as a mapping of model name -> keyword arguments for Model.
supported_models_json = json.loads(decoded_secret)


class Model:
  """Settings needed to send a completion request to one model.

  Attributes:
    name: Name of the model. In this module it is the key of the model's
      entry in the models secret.
    provider: Optional provider name. When set, it is prefixed to the model
      name ("<provider>/<name>") when a completion is requested.
    api_key: Optional API key passed along with completion requests.
    api_base: Optional API base URL passed along with completion requests.
  """

  def __init__(
      self,
      name: str,
      provider: Optional[str] = None,
      # The JSON keys are in camelCase. To unpack these keys into
      # Model attributes, we need to use the same camelCase names.
      apiKey: Optional[str] = None,  # pylint: disable=invalid-name
      apiBase: Optional[str] = None):  # pylint: disable=invalid-name
    """Stores the given settings under snake_case attribute names.

    Args:
      name: Name of the model.
      provider: Optional provider name.
      apiKey: Optional API key; stored as `api_key`.
      apiBase: Optional API base URL; stored as `api_base`.
    """
    self.name = name
    self.provider = provider
    self.api_key = apiKey
    self.api_base = apiBase

  def __repr__(self) -> str:
    # The API key is deliberately omitted so that printing or logging a Model
    # never exposes the credential.
    return (f"{type(self).__name__}(name={self.name!r}, "
            f"provider={self.provider!r}, api_base={self.api_base!r})")


# One Model per entry of the models secret: the entry's key becomes the model
# name and the entry's value is unpacked into Model's keyword arguments.
supported_models: List[Model] = [
    Model(name=entry_name, **entry_settings)
    for entry_name, entry_settings in supported_models_json.items()
]


def completion(model: Model, messages: List, max_tokens: float = None) -> str:
  """Requests a completion from `model` via litellm and returns its text.

  Args:
    model: The model to query. Its `api_key` and `api_base` are forwarded to
      litellm unchanged.
    messages: The messages to send, forwarded to litellm unchanged.
    max_tokens: Forwarded to litellm as `max_tokens`; None by default.

  Returns:
    The message content of the first choice in the litellm response.
  """
  # litellm is given "<provider>/<name>" when a provider is set, otherwise
  # the bare model name.
  if model.provider:
    model_id = model.provider + "/" + model.name
  else:
    model_id = model.name

  response = litellm.completion(
      model=model_id,
      api_key=model.api_key,
      api_base=model.api_base,
      messages=messages,
      max_tokens=max_tokens,
  )

  first_choice = response.choices[0]
  return first_choice.message.content


def check_models(models: List[Model]):
  """Checks that each model answers a minimal completion request.

  Every model is sent a single "Hello." user message limited to 5 tokens, and
  progress is printed to stdout. Checking stops at the first model that fails.

  Args:
    models: The models to check, in order.

  Raises:
    RuntimeError: If the check of a model raises any Exception. The original
      exception is chained as the cause.
  """
  for candidate in models:
    print(f"Checking model {candidate.name}...")
    probe_messages = [{"content": "Hello.", "role": "user"}]
    try:
      completion(model=candidate, messages=probe_messages, max_tokens=5)
      print(f"Model {candidate.name} is available.")

    # The point of this check is to confirm that the model works without any
    # issue at all, so every kind of exception counts as "not available".
    except Exception as err:  # pylint: disable=broad-except
      raise RuntimeError(
          f"Model {candidate.name} is not available: {err}") from err