File size: 3,433 Bytes
2a0aa5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43c8549
 
 
2a0aa5a
9e789e7
 
 
 
2a0aa5a
 
 
 
 
 
 
 
 
43c8549
 
 
2a0aa5a
 
 
 
43c8549
 
 
 
9e789e7
 
 
 
 
 
 
 
 
 
 
 
2a0aa5a
 
 
 
 
 
 
 
 
 
 
 
43c8549
5aa1748
 
 
43c8549
 
 
 
2a0aa5a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
This module contains functions to interact with the models.
"""

import json
import os
from typing import List

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json

# Deployment settings read from the environment. os.environ.get() returns
# None when a variable is unset, and no validation is performed here.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# NOTE: everything below runs at import time. Importing this module builds a
# Secret Manager client from the service-account info returned by
# get_credentials_json() and immediately calls access_secret_version().
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))
# Request the "latest" version of the secret named by MODELS_SECRET in the
# GOOGLE_CLOUD_PROJECT project.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
# The payload is decoded as UTF-8 text before being parsed as JSON.
decoded_secret = models_secret.payload.data.decode("UTF-8")

# Parsed secret. Further down in this module it is iterated with .items() as
# a mapping of model name -> dict of keyword arguments for Model(...).
supported_models_json = json.loads(decoded_secret)

# Fallback instructions that Model uses when a model's configuration does not
# supply its own summarize/translate instruction.
DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the following text, maintaining the language of the text."  # pylint: disable=line-too-long
# The {source_lang} and {target_lang} placeholders are not filled in this
# module; presumably the caller formats them -- verify against callers.
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}."  # pylint: disable=line-too-long


class ContextWindowExceededError(Exception):
  """Raised by Model.completion when litellm reports that the context window was exceeded."""  # pylint: disable=line-too-long


class Model:
  """One configured model that is queried through litellm.

  Attributes:
    name: Model name.
    provider: Optional provider; when truthy, litellm is called with
      "<provider>/<name>" instead of just the name.
    api_key: Value forwarded to litellm as api_key (may be None).
    api_base: Value forwarded to litellm as api_base (may be None).
    summarize_instruction: Summarization instruction; falls back to
      DEFAULT_SUMMARIZE_INSTRUCTION when none (or an empty one) is given.
    translate_instruction: Translation instruction; falls back to
      DEFAULT_TRANSLATE_INSTRUCTION when none (or an empty one) is given.
  """

  # The keyword names below deliberately mirror the camelCase keys of the
  # JSON configuration, so a config dict can be unpacked straight into the
  # constructor with **.
  def __init__(
      self,
      name: str,
      provider: str = None,
      apiKey: str = None,  # pylint: disable=invalid-name
      apiBase: str = None,  # pylint: disable=invalid-name
      summarizeInstruction: str = None,  # pylint: disable=invalid-name
      translateInstruction: str = None):  # pylint: disable=invalid-name
    self.name, self.provider = name, provider
    self.api_key, self.api_base = apiKey, apiBase

    # Any falsy instruction (None or "") is replaced by the module default.
    if not summarizeInstruction:
      summarizeInstruction = DEFAULT_SUMMARIZE_INSTRUCTION
    if not translateInstruction:
      translateInstruction = DEFAULT_TRANSLATE_INSTRUCTION
    self.summarize_instruction = summarizeInstruction
    self.translate_instruction = translateInstruction

  def completion(self, messages: List, max_tokens: float = None) -> str:
    """Runs a completion via litellm and returns the first choice's content.

    Args:
      messages: Messages forwarded unchanged to litellm.completion.
      max_tokens: Forwarded unchanged to litellm.completion.

    Returns:
      The message content of the first choice in the response.

    Raises:
      ContextWindowExceededError: This module's exception, raised when
        litellm raises its own ContextWindowExceededError; the litellm
        exception is chained as the cause.
    """
    # Prefix the name with the provider only when a provider is configured.
    if self.provider:
      model_id = "/".join((self.provider, self.name))
    else:
      model_id = self.name

    try:
      reply = litellm.completion(model=model_id,
                                 api_key=self.api_key,
                                 api_base=self.api_base,
                                 messages=messages,
                                 max_tokens=max_tokens)
      return reply.choices[0].message.content
    except litellm.ContextWindowExceededError as err:
      raise ContextWindowExceededError() from err


# One Model per entry of the parsed secret: the key becomes the model name
# and the value's items are unpacked as the remaining constructor keywords.
supported_models: List[Model] = [
    Model(name=name, **config)
    for name, config in supported_models_json.items()
]


def check_models(models: List[Model]):
  """Sends a tiny test prompt to each model and fails on the first error.

  Progress is reported with print() before and after each probe.

  Args:
    models: Models to probe, in the given order.

  Raises:
    RuntimeError: If probing a model raises any exception. The original
      exception is chained as the cause, and the remaining models are not
      checked.
  """
  for candidate in models:
    print(f"Checking model {candidate.name}...")
    # Built per iteration so every model receives its own message list.
    probe = [
        {
            "role": "system",
            "content": "You are a kind person."
        },
        {
            "role": "user",
            "content": "Hello."
        },
    ]
    try:
      candidate.completion(messages=probe, max_tokens=5)
      print(f"Model {candidate.name} is available.")

    # The goal is to confirm the model works without any issue at all, so
    # every exception type is deliberately caught and reported the same way.
    except Exception as err:  # pylint: disable=broad-except
      failure = f"Model {candidate.name} is not available: {err}"
      raise RuntimeError(failure) from err