Daniel Nichols
Squashed commit of the following:
2233c73
raw
history blame
4.06 kB
""" A light wrapper around a bunch of chat LLMs. The class should define a method that takes text input and returns a response from the model.
"""
from abc import ABC, abstractmethod
from typing import Generator, Optional, AsyncGenerator
import os
import random
import openai
import google.generativeai as genai
from llama_cpp import Llama
class ChatModel(ABC):
    """Abstract base class for every chat backend in this module.

    A model is identified by ``name``, which is also what ``str()`` and
    ``repr()`` return. Concrete subclasses must implement ``get_response``.
    """

    def __init__(self, name):
        # Display / lookup identifier for this model.
        self.name = name

    def __str__(self):
        return self.name

    # repr() deliberately shows the same bare name as str().
    __repr__ = __str__

    @abstractmethod
    def get_response(self, prompt: str) -> Generator[str, None, None]:
        """Stream a reply to *prompt*.

        The implementations in this module yield the cumulative response
        text on every step, so the final yielded value is the full reply.
        """
class DummyModel(ChatModel):
    """Offline stand-in model that echoes the prompt; needs no API key or weights."""

    def __init__(self):
        super().__init__("dummy")

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        """Fake a streamed reply by yielding ever-longer prefixes of a canned echo.

        The first yield is one character long; the last is the whole string.
        """
        reply = f"Dummy response to: {prompt}"
        for end in range(1, len(reply) + 1):
            yield reply[:end]
class OpenAIModel(ChatModel):
    """Chat model served through the OpenAI chat-completions API."""

    def __init__(self, model: str, client: openai.OpenAI):
        """Create an OpenAI-backed chat model.

        Args:
            model: OpenAI model identifier (e.g. ``"gpt-4o-mini"``); also used
                as this object's display name.
            client: Configured OpenAI client used to send requests.
        """
        super().__init__(model)
        self.model = model
        self.client = client

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        """Stream the model's reply to *prompt*.

        Sends a PerfGuru system prompt plus *prompt* as the user message and
        yields the cumulative response text after each streamed chunk.
        """
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are PerfGuru, a helpful assistant for assisting developers in identifying performance bottlenecks in their code and optimizing them."},
                {"role": "user", "content": prompt},
            ],
            stream=True,
            max_tokens=4096,
        )
        response = ""
        for chunk in stream:
            # Some streamed chunks carry an empty ``choices`` list (e.g. Azure
            # OpenAI's prompt-filter chunk, or a trailing usage chunk when
            # usage reporting is enabled); indexing [0] on them would raise.
            if not chunk.choices:
                continue
            # ``delta.content`` is None on role-only / finish chunks.
            response += chunk.choices[0].delta.content or ""
            yield response
class GeminiModel(ChatModel):
    """Chat model backed by Google's Gemini API via ``google.generativeai``."""

    def __init__(self, model: str, api_key: Optional[str] = None):
        """Create a Gemini-backed chat model.

        Args:
            model: Gemini model name (e.g. ``"gemini-1.5-flash"``); also used
                as this object's display name.
            api_key: Optional key passed to ``genai.configure``; a falsy value
                skips that call. NOTE(review): ``genai.configure`` is called on
                the module, so the key presumably applies process-wide to all
                Gemini models — confirm.
        """
        super().__init__(model)
        if api_key:
            genai.configure(api_key=api_key)
        # Request a single candidate, capped at 4096 output tokens.
        generation_config = genai.types.GenerationConfig(
            candidate_count=1,
            max_output_tokens=4096,
        )
        self.model = genai.GenerativeModel(model)
        self.config = generation_config

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        """Stream Gemini's reply to *prompt*, yielding the text accumulated so far.

        NOTE(review): in google.generativeai, ``.text`` raises ValueError
        (rather than returning None) when a chunk has no text parts, e.g. a
        safety-blocked response; the ``or ""`` fallback does not cover that —
        confirm the desired handling.
        """
        text_so_far = ""
        pieces = self.model.generate_content(
            prompt, generation_config=self.config, stream=True
        )
        for piece in pieces:
            text_so_far = text_so_far + (piece.text or "")
            yield text_so_far
class LocalModel(ChatModel):
    """Chat model run in-process from a local GGUF file through ``llama_cpp``."""

    def __init__(
        self,
        model: str,
        model_path: str,
        *,
        n_ctx: int = 8000,
        max_tokens: int = 4000,
    ):
        """Create the ``llama_cpp.Llama`` instance eagerly.

        Args:
            model: Display name for this model.
            model_path: Path to the GGUF weights file handed to ``Llama``.
            n_ctx: Context window size passed to ``Llama`` (default 8000).
            max_tokens: Generation cap used for every completion (default 4000).
        """
        super().__init__(model)
        self.max_tokens = max_tokens
        self.llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
        )

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        """Answer *prompt*, then replay the answer as ever-longer prefixes.

        The completion runs non-streaming, so nothing is yielded until the
        whole reply is ready. It is then yielded one character at a time to
        match the cumulative-string protocol of the other models.
        """
        output = self.llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are PerfGuru, a helpful assistant for assisting developers in identifying performance bottlenecks in their code and optimizing them."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=self.max_tokens,
        )
        # llama_cpp types message content as Optional[str]; treat None as empty
        # rather than crashing on len(None).
        result = output["choices"][0]["message"]["content"] or ""
        for end in range(1, len(result) + 1):
            yield result[:end]
# File stems of GGUF models expected under local_models/.
LOCAL_MODELS = [
    "Meta-Llama-3-8B-Instruct.Q4_K_S",
]

# Registry of every model the app can use. Local models whose weights file is
# missing are skipped instead of crashing the import, so API-backed models can
# still be used on machines without the local weights.
AVAILABLE_MODELS = [
    LocalModel(model_name, f"local_models/{model_name}.gguf")
    for model_name in LOCAL_MODELS
    if os.path.isfile(f"local_models/{model_name}.gguf")
]

# Uncomment to add an offline echo model for testing.
# AVAILABLE_MODELS.append( DummyModel() )

# API-backed models are registered only when their key is present in the environment.
if os.environ.get("OPENAI_API_KEY"):
    openai_client = openai.OpenAI()
    AVAILABLE_MODELS.append(OpenAIModel("gpt-4o-mini", openai_client))
    AVAILABLE_MODELS.append(OpenAIModel("gpt-3.5-turbo", openai_client))

if os.environ.get("GOOGLE_API_KEY"):
    AVAILABLE_MODELS.append(GeminiModel("gemini-1.5-flash"))
    AVAILABLE_MODELS.append(GeminiModel("gemini-1.5-pro"))

# Now reachable: it fires when no local weights exist and no API key is set.
if not AVAILABLE_MODELS:
    raise ValueError(
        "No models available. Please set OPENAI_API_KEY or GOOGLE_API_KEY "
        "environment variables, or add the .gguf files listed in LOCAL_MODELS "
        "under local_models/."
    )
def select_random_model() -> ChatModel:
    """Return one entry of ``AVAILABLE_MODELS``, picked uniformly at random.

    Uses the global ``random`` generator, so a prior ``random.seed`` call
    makes the choice reproducible.
    """
    candidates = AVAILABLE_MODELS
    return random.choice(candidates)