Update src.py
Browse files
src.py
CHANGED
@@ -6,12 +6,14 @@ from PIL import Image
|
|
6 |
from pandasai.llm import HuggingFaceTextGen
|
7 |
from dotenv import load_dotenv
|
8 |
from langchain_groq.chat_models import ChatGroq
|
|
|
9 |
|
10 |
load_dotenv()
|
11 |
Groq_Token = os.environ["GROQ_API_KEY"]
|
12 |
-
models = {"mixtral": "mixtral-8x7b-32768", "llama": "llama2-70b-4096", "gemma": "gemma-7b-it"}
|
13 |
|
14 |
hf_token = os.getenv("HF_READ")
|
|
|
15 |
|
16 |
def preprocess_and_load_df(path: str) -> pd.DataFrame:
|
17 |
df = pd.read_csv(path)
|
@@ -27,7 +29,10 @@ def load_agent(df: pd.DataFrame, context: str, inference_server: str, name="mixt
|
|
27 |
# top_k=5,
|
28 |
# )
|
29 |
# llm.client.headers = {"Authorization": f"Bearer {hf_token}"}
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
agent = Agent(df, config={"llm": llm, "enable_cache": False, "options": {"wait_for_model": True}})
|
33 |
agent.add_message(context)
|
@@ -86,7 +91,10 @@ def show_response(st, response):
|
|
86 |
return {"is_image": False}
|
87 |
|
88 |
def ask_question(model_name, question):
|
89 |
-
|
|
|
|
|
|
|
90 |
|
91 |
df_check = pd.read_csv("Data.csv")
|
92 |
df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
|
@@ -121,11 +129,13 @@ df["Timestamp"] = pd.to_datetime(df["Timestamp"])
|
|
121 |
{template}
|
122 |
|
123 |
"""
|
124 |
-
|
125 |
-
|
|
|
|
|
126 |
code = f"""
|
127 |
{template.split("```python")[1].split("```")[0]}
|
128 |
-
{answer.
|
129 |
"""
|
130 |
# update variable `answer` when code is executed
|
131 |
exec(code)
|
|
|
6 |
from pandasai.llm import HuggingFaceTextGen
|
7 |
from dotenv import load_dotenv
|
8 |
from langchain_groq.chat_models import ChatGroq
|
9 |
+
from langchain_google_genai import GoogleGenerativeAI
|
10 |
|
11 |
load_dotenv()
|
12 |
Groq_Token = os.environ["GROQ_API_KEY"]
|
13 |
+
models = {"mixtral": "mixtral-8x7b-32768", "llama": "llama2-70b-4096", "gemma": "gemma-7b-it", "gemini-pro": "gemini-pro"}
|
14 |
|
15 |
hf_token = os.getenv("HF_READ")
|
16 |
+
gemini_token = os.getenv("GEMINI_TOKEN")
|
17 |
|
18 |
def preprocess_and_load_df(path: str) -> pd.DataFrame:
|
19 |
df = pd.read_csv(path)
|
|
|
29 |
# top_k=5,
|
30 |
# )
|
31 |
# llm.client.headers = {"Authorization": f"Bearer {hf_token}"}
|
32 |
+
if name == "gemini-pro":
|
33 |
+
llm = GoogleGenerativeAI(model=model, google_api_key=gemini_token, temperature=0.1)
|
34 |
+
else:
|
35 |
+
llm = ChatGroq(model=models[name], api_key=os.getenv("GROQ_API"), temperature=0.1)
|
36 |
|
37 |
agent = Agent(df, config={"llm": llm, "enable_cache": False, "options": {"wait_for_model": True}})
|
38 |
agent.add_message(context)
|
|
|
91 |
return {"is_image": False}
|
92 |
|
93 |
def ask_question(model_name, question):
|
94 |
+
if model_name == "gemini-pro":
|
95 |
+
llm = GoogleGenerativeAI(model=model, google_api_key=os.environ.get("GOOGLE_API_KEY"), temperature=0)
|
96 |
+
else:
|
97 |
+
llm = ChatGroq(model=models[model_name], api_key=os.getenv("GROQ_API"), temperature=0.1)
|
98 |
|
99 |
df_check = pd.read_csv("Data.csv")
|
100 |
df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
|
|
|
129 |
{template}
|
130 |
|
131 |
"""
|
132 |
+
if model_name == "gemini-pro":
|
133 |
+
answer = llm.invoke(query)
|
134 |
+
else:
|
135 |
+
answer = llm.invoke(query).content
|
136 |
code = f"""
|
137 |
{template.split("```python")[1].split("```")[0]}
|
138 |
+
{answer.split("```python")[1].split("```")[0]}
|
139 |
"""
|
140 |
# update variable `answer` when code is executed
|
141 |
exec(code)
|