import os
from enum import Enum
from typing import Any, Optional, Union

import instructor
import weave
from PIL import Image

from ..utils import base64_encode_image


class ClientType(str, Enum):
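    """Supported LLM provider backends."""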
    GEMINI = "gemini"
    MISTRAL = "mistral"


class LLMClient(weave.Model):
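    """A weave.Model that dispatches prompts to the Gemini or Mistral SDK."""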
    model_name: str
    client_type: ClientType

    def __init__(self, model_name: str, client_type: ClientType):
        super().__init__(model_name=model_name, client_type=client_type)

    @weave.op()
    def execute_gemini_sdk(
        self,
        user_prompt: Union[str, Image.Image, list[Union[str, Image.Image]]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
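        """Generate a completion with the Gemini SDK. When `schema` is given,
        the model is asked for JSON conforming to `list[schema]` and the raw
        response is returned; otherwise the response text is returned."""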
        import google.generativeai as genai

        # Normalize both prompts to lists; a missing system prompt becomes an
        # empty list so it can be safely concatenated below.
        if system_prompt is None:
            system_prompt = []
        elif not isinstance(system_prompt, list):
            system_prompt = [system_prompt]
        if not isinstance(user_prompt, list):
            user_prompt = [user_prompt]

        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        model = genai.GenerativeModel(self.model_name)
        generation_config = (
            None
            if schema is None
            else genai.GenerationConfig(
                response_mime_type="application/json", response_schema=list[schema]
            )
        )
        response = model.generate_content(
            system_prompt + user_prompt, generation_config=generation_config
        )
        return response.text if schema is None else response

    @weave.op()
    def execute_mistral_sdk(
        self,
        user_prompt: Union[str, Image.Image, list[Union[str, Image.Image]]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
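        """Generate a completion with the Mistral SDK. PIL images in the user
        prompt are sent as base64-encoded image chunks; when `schema` is given,
        instructor parses the response into an instance of that schema."""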
        from mistralai import Mistral

        # Normalize both prompts to lists; a missing system prompt becomes an
        # empty list so the message-building below can handle it.
        if system_prompt is None:
            system_prompt = []
        elif not isinstance(system_prompt, list):
            system_prompt = [system_prompt]
        if not isinstance(user_prompt, list):
            user_prompt = [user_prompt]
        system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
        user_messages = []
        for prompt in user_prompt:
            if isinstance(prompt, Image.Image):
                user_messages.append(
                    {
                        "type": "image_url",
                        "image_url": base64_encode_image(prompt, "image/png"),
                    }
                )
            else:
                user_messages.append({"type": "text", "text": prompt})
        # Omit the system message entirely when no system prompt was given.
        messages = []
        if system_messages:
            messages.append({"role": "system", "content": system_messages})
        messages.append({"role": "user", "content": user_messages})

        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))

        if schema is not None:
            # instructor returns the parsed `schema` instance directly rather
            # than a raw chat completion, so it is returned as-is.
            structured_client = instructor.from_mistral(client)
            return structured_client.messages.create(
                model=self.model_name,
                response_model=schema,
                messages=messages,
                temperature=0,
            )

        response = client.chat.complete(model=self.model_name, messages=messages)
        return response.choices[0].message.content

    @weave.op()
    def predict(
        self,
        user_prompt: Union[str, Image.Image, list[Union[str, Image.Image]]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
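        """Route the prompt to the SDK selected by `client_type`."""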
        if self.client_type == ClientType.GEMINI:
            return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
        elif self.client_type == ClientType.MISTRAL:
            return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
        else:
            raise ValueError(f"Invalid client type: {self.client_type}")
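

# A minimal usage sketch (assumes GOOGLE_API_KEY is set in the environment and
# that "gemini-1.5-flash" is a model name available to your account):
#
#   client = LLMClient(model_name="gemini-1.5-flash", client_type=ClientType.GEMINI)
#   print(client.predict(user_prompt="What is the capital of France?"))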