""" A light wrapper around a bunch of chat LLMs. The class should define a method that takes text input and returns a response from the model.
"""
from abc import ABC, abstractmethod
from typing import Generator, Optional, AsyncGenerator
import os
import random
import glob
import openai
import google.generativeai as genai
from llama_cpp import Llama
from huggingface_hub import InferenceClient
class ChatModel(ABC):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    @abstractmethod
    def get_response(self, prompt) -> Generator[str, None, None]:
        pass
class DummyModel(ChatModel):
    def __init__(self):
        super().__init__("dummy")

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        # Stream the canned response one character at a time, yielding the
        # growing prefix so callers see the same interface as the real models.
        response = f"Dummy response to: {prompt}"
        for idx in range(len(response)):
            yield response[:idx + 1]
class OpenAIModel(ChatModel):
    def __init__(self, model: str, client: openai.OpenAI):
        super().__init__(model)
        self.model = model
        self.client = client

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are PerfGuru, a helpful assistant for assisting developers in identifying performance bottlenecks in their code and optimizing them."},
                {"role": "user", "content": prompt}
            ],
            stream=True,
            max_tokens=4096,
        )

        # Accumulate the streamed deltas and yield the full response so far.
        response = ""
        for chunk in stream:
            response += chunk.choices[0].delta.content or ""
            yield response
class GeminiModel(ChatModel):
    def __init__(self, model: str, api_key: Optional[str] = None):
        super().__init__(model)
        # If no key is passed explicitly, we rely on the key already configured
        # via the environment (GOOGLE_API_KEY).
        if api_key:
            genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model)
        self.config = genai.types.GenerationConfig(
            candidate_count=1,
            max_output_tokens=4096,
        )

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        stream = self.model.generate_content(prompt, stream=True, generation_config=self.config)

        response = ""
        for chunk in stream:
            response += chunk.text or ""
            yield response
class LocalModel(ChatModel):
    def __init__(self, model: str, model_path: str):
        super().__init__(model)
        self.llm = Llama(
            model_path=model_path,
            n_ctx=4096,
        )

    def get_response(self, prompt) -> Generator[str, None, None]:
        outputs = self.llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are PerfGuru, a helpful assistant for assisting developers in identifying performance bottlenecks in their code and optimizing them."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=4000,
            stream=True,
        )

        response = ""
        for chunk in outputs:
            response += chunk['choices'][0]['delta'].get('content', '')
            yield response
class InferenceHubModel(ChatModel):
    def __init__(self, model: str, client: InferenceClient, supports_system_messages: bool = True):
        super().__init__(model)
        self.model = model
        self.client = client
        # Some hosted models (e.g. Gemma) reject system messages, so allow opting out.
        self.supports_system_messages = supports_system_messages

    def get_response(self, prompt: str) -> Generator[str, None, None]:
        messages = []
        if self.supports_system_messages:
            messages.append({"role": "system", "content": "You are PerfGuru, a helpful assistant for assisting developers in identifying performance bottlenecks in their code and optimizing them."})
        messages.append({"role": "user", "content": prompt})

        stream = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            stream=True,
            max_tokens=2048,
        )

        response = ""
        for chunk in stream:
            response += chunk.choices[0].delta.content or ""
            yield response
AVAILABLE_MODELS = []

if os.environ.get("USE_LOCAL_MODELS") == "1":
    # Discover GGUF checkpoints in the Hugging Face cache and the local_models directory.
    HF_HOME = os.environ.get("HF_HOME", "/home/user/.cache/huggingface")
    GGUF_WILDCARD = os.path.join(HF_HOME, "hub", "models-*", "**", "*.gguf")
    GGUF_PATHS = [(os.path.basename(p), p) for p in glob.glob(GGUF_WILDCARD, recursive=True)]
    LOCAL_MODEL_PATHS = [(os.path.basename(p), p) for p in glob.glob(os.path.join("local_models", "*.gguf"))]
    ALL_LOCAL_MODELS = GGUF_PATHS + LOCAL_MODEL_PATHS

    AVAILABLE_MODELS.extend([
        LocalModel(model_name, model_path)
        for model_name, model_path in ALL_LOCAL_MODELS
        if os.path.exists(model_path)
    ])

# AVAILABLE_MODELS.append( DummyModel() )
if os.environ.get("OPENAI_API_KEY"):
openai_client = openai.OpenAI()
AVAILABLE_MODELS.append( OpenAIModel("gpt-4o-mini", openai_client) )
AVAILABLE_MODELS.append( OpenAIModel("gpt-3.5-turbo", openai_client) )
if os.environ.get("GOOGLE_API_KEY"):
AVAILABLE_MODELS.append( GeminiModel("gemini-1.5-flash") )
AVAILABLE_MODELS.append( GeminiModel("gemini-1.5-pro") )
if os.environ.get("HF_API_KEY"):
hf_inference_client = InferenceClient(api_key=os.environ.get("HF_API_KEY"), timeout=60)
#AVAILABLE_MODELS.append( InferenceHubModel("google/gemma-2-2b-it", hf_inference_client, supports_system_messages=False) )
#AVAILABLE_MODELS.append( InferenceHubModel("Qwen/Qwen2.5-7B-Instruct", hf_inference_client) )
AVAILABLE_MODELS.append( InferenceHubModel("microsoft/Phi-3-mini-4k-instruct", hf_inference_client) )
AVAILABLE_MODELS.append( InferenceHubModel("meta-llama/Llama-3.2-1B-Instruct", hf_inference_client) )
AVAILABLE_MODELS.append( InferenceHubModel("meta-llama/Llama-3.2-3B-Instruct", hf_inference_client) )
AVAILABLE_MODELS.append( InferenceHubModel("meta-llama/Meta-Llama-3.1-8B-Instruct", hf_inference_client) )
if not AVAILABLE_MODELS:
    raise ValueError("No models available. Please set the OPENAI_API_KEY, GOOGLE_API_KEY, or HF_API_KEY environment variable, or set USE_LOCAL_MODELS=1 with GGUF models available locally.")
def select_random_model() -> ChatModel:
    return random.choice(AVAILABLE_MODELS)
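

# A minimal usage sketch: pick a model at random and stream a reply. It assumes
# at least one API key or local GGUF model is configured; the prompt string is
# just an illustrative placeholder.
if __name__ == "__main__":
    model = select_random_model()
    print(f"Using model: {model}")

    final = ""
    for partial in model.get_response("How can I speed up a tight Python loop?"):
        final = partial  # each yield is the full response accumulated so far
    print(final)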