Zeel commited on
Commit
042ce46
1 Parent(s): 339d432

Update src.py

Browse files
Files changed (1) hide show
  1. src.py +16 -6
src.py CHANGED
@@ -6,12 +6,14 @@ from PIL import Image
6
  from pandasai.llm import HuggingFaceTextGen
7
  from dotenv import load_dotenv
8
  from langchain_groq.chat_models import ChatGroq
 
9
 
10
  load_dotenv()
11
  Groq_Token = os.environ["GROQ_API_KEY"]
12
- models = {"mixtral": "mixtral-8x7b-32768", "llama": "llama2-70b-4096", "gemma": "gemma-7b-it"}
13
 
14
  hf_token = os.getenv("HF_READ")
 
15
 
16
  def preprocess_and_load_df(path: str) -> pd.DataFrame:
17
  df = pd.read_csv(path)
@@ -27,7 +29,10 @@ def load_agent(df: pd.DataFrame, context: str, inference_server: str, name="mixt
27
  # top_k=5,
28
  # )
29
  # llm.client.headers = {"Authorization": f"Bearer {hf_token}"}
30
- llm = ChatGroq(model=models[name], api_key=os.getenv("GROQ_API"), temperature=0.1)
 
 
 
31
 
32
  agent = Agent(df, config={"llm": llm, "enable_cache": False, "options": {"wait_for_model": True}})
33
  agent.add_message(context)
@@ -86,7 +91,10 @@ def show_response(st, response):
86
  return {"is_image": False}
87
 
88
  def ask_question(model_name, question):
89
- llm = ChatGroq(model=models[model_name], api_key=os.getenv("GROQ_API"), temperature=0.1)
 
 
 
90
 
91
  df_check = pd.read_csv("Data.csv")
92
  df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
@@ -121,11 +129,13 @@ df["Timestamp"] = pd.to_datetime(df["Timestamp"])
121
  {template}
122
 
123
  """
124
-
125
- answer = llm.invoke(query)
 
 
126
  code = f"""
127
  {template.split("```python")[1].split("```")[0]}
128
- {answer.content.split("```python")[1].split("```")[0]}
129
  """
130
  # update variable `answer` when code is executed
131
  exec(code)
 
6
  from pandasai.llm import HuggingFaceTextGen
7
  from dotenv import load_dotenv
8
  from langchain_groq.chat_models import ChatGroq
9
+ from langchain_google_genai import GoogleGenerativeAI
10
 
11
  load_dotenv()
12
  Groq_Token = os.environ["GROQ_API_KEY"]
13
+ models = {"mixtral": "mixtral-8x7b-32768", "llama": "llama2-70b-4096", "gemma": "gemma-7b-it", "gemini-pro": "gemini-pro"}
14
 
15
  hf_token = os.getenv("HF_READ")
16
+ gemini_token = os.getenv("GEMINI_TOKEN")
17
 
18
  def preprocess_and_load_df(path: str) -> pd.DataFrame:
19
  df = pd.read_csv(path)
 
29
  # top_k=5,
30
  # )
31
  # llm.client.headers = {"Authorization": f"Bearer {hf_token}"}
32
+ if name == "gemini-pro":
33
+ llm = GoogleGenerativeAI(model=model, google_api_key=gemini_token, temperature=0.1)
34
+ else:
35
+ llm = ChatGroq(model=models[name], api_key=os.getenv("GROQ_API"), temperature=0.1)
36
 
37
  agent = Agent(df, config={"llm": llm, "enable_cache": False, "options": {"wait_for_model": True}})
38
  agent.add_message(context)
 
91
  return {"is_image": False}
92
 
93
  def ask_question(model_name, question):
94
+ if model_name == "gemini-pro":
95
+ llm = GoogleGenerativeAI(model=model, google_api_key=os.environ.get("GOOGLE_API_KEY"), temperature=0)
96
+ else:
97
+ llm = ChatGroq(model=models[model_name], api_key=os.getenv("GROQ_API"), temperature=0.1)
98
 
99
  df_check = pd.read_csv("Data.csv")
100
  df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
 
129
  {template}
130
 
131
  """
132
+ if model_name == "gemini-pro":
133
+ answer = llm.invoke(query)
134
+ else:
135
+ answer = llm.invoke(query).content
136
  code = f"""
137
  {template.split("```python")[1].split("```")[0]}
138
+ {answer.split("```python")[1].split("```")[0]}
139
  """
140
  # update variable `answer` when code is executed
141
  exec(code)