Spaces:
Running
Running
# trainer_manager.py | |
from longtrainer.trainer import LongTrainer | |
from langchain_groq import ChatGroq | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from config import CONNECTION_STRING, CHATGROQ_API_KEY, CUSTOM_PROMPT | |
def get_embeddings(): | |
# Initialize HuggingFace embeddings with the specified model and parameters | |
model_name = "BAAI/bge-small-en" | |
model_kwargs = {"device": "cpu"} | |
encode_kwargs = {"normalize_embeddings": True} | |
embeddings = HuggingFaceEmbeddings( | |
model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs | |
) | |
return embeddings | |
def get_llm(): | |
if not CHATGROQ_API_KEY: | |
raise ValueError("CHATGROQ_API_KEY is not set.") | |
llm = ChatGroq( | |
model="llama-3.3-70b-versatile", | |
temperature=0, | |
max_tokens=1024, | |
api_key=CHATGROQ_API_KEY | |
) | |
return llm | |
embedding_model = get_embeddings() | |
llm = get_llm() | |
# Create a global LongTrainer instance | |
trainer_instance = LongTrainer( | |
mongo_endpoint=CONNECTION_STRING, | |
llm=llm, | |
embedding_model=embedding_model, | |
encrypt_chats=True | |
) | |
def get_trainer(): | |
return trainer_instance | |