File size: 1,159 Bytes
8176bbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# trainer_manager.py
from longtrainer.trainer import LongTrainer
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from config import CONNECTION_STRING, CHATGROQ_API_KEY, CUSTOM_PROMPT

def get_embeddings(
    model_name: str = "BAAI/bge-small-en",
    device: str = "cpu",
    normalize_embeddings: bool = True,
):
    """Build the HuggingFace embedding model used by the trainer.

    The defaults reproduce the previously hard-coded configuration, so
    ``get_embeddings()`` with no arguments behaves exactly as before.

    Args:
        model_name: HuggingFace model identifier passed to
            ``HuggingFaceEmbeddings``.
        device: Device string forwarded in ``model_kwargs`` (default ``"cpu"``).
        normalize_embeddings: Value forwarded in ``encode_kwargs`` as the
            encoder's ``normalize_embeddings`` option.

    Returns:
        A configured ``HuggingFaceEmbeddings`` instance.
    """
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )

def get_llm():
    """Create the Groq-hosted chat model used by the trainer.

    Uses the ``llama-3.3-70b-versatile`` model with deterministic sampling
    (temperature 0) and a 1024-token output cap.

    Raises:
        ValueError: If ``CHATGROQ_API_KEY`` is empty or unset.
    """
    api_key = CHATGROQ_API_KEY
    if api_key:
        return ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=1024,
            api_key=api_key,
        )
    raise ValueError("CHATGROQ_API_KEY is not set.")

# Shared dependencies, built once when this module is imported.
# get_llm() runs first: it raises ValueError when CHATGROQ_API_KEY is missing,
# so a misconfigured environment fails before the embedding model (likely the
# slower step) is loaded.
llm = get_llm()
embedding_model = get_embeddings()

# Create a global LongTrainer instance; get_trainer() returns this object.
# NOTE(review): encrypt_chats=True presumably makes LongTrainer encrypt stored
# chat history — confirm against the longtrainer documentation.
trainer_instance = LongTrainer(
    mongo_endpoint=CONNECTION_STRING,
    llm=llm,
    embedding_model=embedding_model,
    encrypt_chats=True,
)

def get_trainer():
    """Return the module-level ``LongTrainer`` instance.

    The instance is created once at import time, so every call returns
    the same object.
    """
    trainer = trainer_instance
    return trainer