mominah commited on
Commit
8176bbf
·
verified ·
1 Parent(s): 10e6638

Create trainer_manager.py

Browse files
Files changed (1) hide show
  1. trainer_manager.py +40 -0
trainer_manager.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # trainer_manager.py
2
+ from longtrainer.trainer import LongTrainer
3
+ from langchain_groq import ChatGroq
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from config import CONNECTION_STRING, CHATGROQ_API_KEY, CUSTOM_PROMPT
6
+
7
+ def get_embeddings():
8
+ # Initialize HuggingFace embeddings with the specified model and parameters
9
+ model_name = "BAAI/bge-small-en"
10
+ model_kwargs = {"device": "cpu"}
11
+ encode_kwargs = {"normalize_embeddings": True}
12
+ embeddings = HuggingFaceEmbeddings(
13
+ model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
14
+ )
15
+ return embeddings
16
+
17
+ def get_llm():
18
+ if not CHATGROQ_API_KEY:
19
+ raise ValueError("CHATGROQ_API_KEY is not set.")
20
+ llm = ChatGroq(
21
+ model="llama-3.3-70b-versatile",
22
+ temperature=0,
23
+ max_tokens=1024,
24
+ api_key=CHATGROQ_API_KEY
25
+ )
26
+ return llm
27
+
28
+ embedding_model = get_embeddings()
29
+ llm = get_llm()
30
+
31
+ # Create a global LongTrainer instance
32
+ trainer_instance = LongTrainer(
33
+ mongo_endpoint=CONNECTION_STRING,
34
+ llm=llm,
35
+ embedding_model=embedding_model,
36
+ encrypt_chats=True
37
+ )
38
+
39
+ def get_trainer():
40
+ return trainer_instance