Spaces:
Running
Running
File size: 3,612 Bytes
8d64162 49cde8e 01ed12d 8d64162 01ed12d 8d64162 01ed12d 8d64162 35cb430 8d64162 35cb430 8d64162 35cb430 8d64162 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import os
from enum import Enum
from typing import Any, Optional, Union
import instructor
import weave
from PIL import Image
from ..utils import base64_encode_image
class ClientType(str, Enum):
    """Supported LLM provider backends for ``LLMClient``.

    ``str`` is mixed in (and must come *before* ``Enum`` in the bases, or
    enum class creation raises ``TypeError``) so each member compares equal
    to its string value, e.g. ``ClientType.GEMINI == "gemini"``.
    """

    GEMINI = "gemini"
    MISTRAL = "mistral"
def _as_prompt_list(prompt: Any) -> list[Any]:
    """Normalize a prompt argument into a new (possibly empty) list of parts.

    ``None`` becomes ``[]``; a single string or a single PIL image becomes a
    one-element list; any other value is treated as an iterable of parts and
    copied into a fresh list so the caller's object is never mutated.
    """
    if prompt is None:
        return []
    if isinstance(prompt, (str, Image.Image)):
        return [prompt]
    return list(prompt)


class LLMClient(weave.Model):
    """Thin, weave-traced wrapper that sends prompts to Gemini or Mistral.

    The backend is selected by ``client_type``; ``predict`` dispatches to the
    matching ``execute_*_sdk`` method.
    """

    # Model identifier passed straight through to the provider SDK.
    model_name: str
    # Which provider SDK ``predict`` routes requests to.
    client_type: ClientType

    def __init__(self, model_name: str, client_type: ClientType):
        super().__init__(model_name=model_name, client_type=client_type)

    @weave.op()
    def execute_gemini_sdk(
        self,
        user_prompt: Union[str, list[str]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Send the prompts to Gemini via ``google.generativeai``.

        Args:
            user_prompt: A single prompt part or a list of parts.
            system_prompt: Optional system part(s); they are prepended to the
                user parts in the request contents. ``None`` means no system
                parts.
            schema: Optional schema. When given, JSON output is requested with
                ``response_schema=list[schema]`` (i.e. a JSON array of items).

        Returns:
            ``response.text`` when ``schema`` is None, otherwise the raw
            response object returned by ``generate_content``.
        """
        # Imported lazily so the Gemini SDK is only required when used.
        import google.generativeai as genai

        system_parts = _as_prompt_list(system_prompt)
        user_parts = _as_prompt_list(user_prompt)

        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        model = genai.GenerativeModel(self.model_name)
        generation_config = (
            None
            if schema is None
            else genai.GenerationConfig(
                response_mime_type="application/json", response_schema=list[schema]
            )
        )
        response = model.generate_content(
            system_parts + user_parts, generation_config=generation_config
        )
        return response.text if schema is None else response

    @weave.op()
    def execute_mistral_sdk(
        self,
        user_prompt: Union[str, list[str]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Send the prompts to Mistral, optionally parsing into ``schema``.

        Args:
            user_prompt: A single prompt part or a list of parts. PIL images
                are sent as ``image_url`` chunks; everything else as text.
            system_prompt: Optional system text(s). When ``None`` or empty, no
                system message is sent.
            schema: Optional ``response_model`` for ``instructor``. When given,
                the request goes through ``instructor.from_mistral`` with
                ``temperature=0``.

        Returns:
            The text content of the first choice when ``schema`` is None,
            otherwise the object instructor parses into ``schema``.
        """
        # Imported lazily so the Mistral SDK is only required when used.
        from mistralai import Mistral

        system_parts = _as_prompt_list(system_prompt)
        user_parts = _as_prompt_list(user_prompt)

        user_chunks = []
        for part in user_parts:
            if isinstance(part, Image.Image):
                # NOTE(review): assumes base64_encode_image returns a value the
                # Mistral API accepts for "image_url" (e.g. a data URI) — confirm.
                user_chunks.append(
                    {
                        "type": "image_url",
                        "image_url": base64_encode_image(part, "image/png"),
                    }
                )
            else:
                user_chunks.append({"type": "text", "text": part})

        messages = []
        # Only send a system message when there is system text; an empty
        # content list is not a meaningful system message.
        if system_parts:
            messages.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": p} for p in system_parts],
                }
            )
        messages.append({"role": "user", "content": user_chunks})

        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
        if schema is None:
            # Plain completion: use the raw SDK client, not the instructor wrapper.
            response = client.chat.complete(model=self.model_name, messages=messages)
            return response.choices[0].message.content

        # Structured output: instructor forwards extra kwargs (model,
        # temperature) to the underlying chat.complete call and returns the
        # parsed response_model instance, not a raw completion.
        structured_client = instructor.from_mistral(client)
        return structured_client.messages.create(
            model=self.model_name,
            response_model=schema,
            messages=messages,
            temperature=0,
        )

    @weave.op()
    def predict(
        self,
        user_prompt: Union[str, list[str]],
        system_prompt: Optional[Union[str, list[str]]] = None,
        schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Route the request to the SDK method matching ``client_type``.

        Raises:
            ValueError: If ``client_type`` is not a supported ``ClientType``.
        """
        if self.client_type == ClientType.GEMINI:
            return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
        elif self.client_type == ClientType.MISTRAL:
            return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
        else:
            raise ValueError(f"Invalid client type: {self.client_type}")
|