Sivan Ratson
commited on
Commit
โข
d61d72d
1
Parent(s):
8d38779
use API provided by the user.
Browse files- .gitignore +3 -0
- agent_workflow.py +4 -4
- app.py +101 -39
- llm_providers.py +10 -13
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
index
|
3 |
+
__pycache__
|
agent_workflow.py
CHANGED
@@ -9,17 +9,17 @@ from tantivy_search_agent import TantivySearchAgent
|
|
9 |
load_dotenv()
|
10 |
|
11 |
class SearchAgent:
|
12 |
-
def __init__(self, tantivy_agent: TantivySearchAgent, provider_name: str = "Gemini"):
|
13 |
"""Initialize the search agent with Tantivy agent and LLM client"""
|
14 |
self.tantivy_agent = tantivy_agent
|
15 |
self.logger = logging.getLogger(__name__)
|
16 |
|
17 |
-
# Initialize LLM provider
|
18 |
-
self.llm_provider = LLMProvider()
|
19 |
self.llm = None
|
20 |
self.set_provider(provider_name)
|
21 |
|
22 |
-
self.min_confidence_threshold = 0.
|
23 |
|
24 |
def set_provider(self, provider_name: str) -> None:
|
25 |
self.llm = self.llm_provider.get_provider(provider_name)
|
|
|
9 |
load_dotenv()
|
10 |
|
11 |
class SearchAgent:
|
12 |
+
def __init__(self, tantivy_agent: TantivySearchAgent, provider_name: str = "Gemini", api_keys: Dict[str, str] = None):
|
13 |
"""Initialize the search agent with Tantivy agent and LLM client"""
|
14 |
self.tantivy_agent = tantivy_agent
|
15 |
self.logger = logging.getLogger(__name__)
|
16 |
|
17 |
+
# Initialize LLM provider with API keys
|
18 |
+
self.llm_provider = LLMProvider(api_keys)
|
19 |
self.llm = None
|
20 |
self.set_provider(provider_name)
|
21 |
|
22 |
+
self.min_confidence_threshold = 0.7
|
23 |
|
24 |
def set_provider(self, provider_name: str) -> None:
|
25 |
self.llm = self.llm_provider.get_provider(provider_name)
|
app.py
CHANGED
@@ -16,6 +16,7 @@ class SearchAgentUI:
|
|
16 |
self.index_path ="./index" # os.getenv("INDEX_PATH", "./index")
|
17 |
# Google Drive folder ID for the index
|
18 |
self.gdrive_index_id = os.getenv("GDRIVE_INDEX_ID", "1lpbBCPimwcNfC0VZOlQueA4SHNGIp5_t")
|
|
|
19 |
def download_index_from_gdrive(self) -> bool:
|
20 |
"""Download index folder from Google Drive"""
|
21 |
try:
|
@@ -30,20 +31,18 @@ class SearchAgentUI:
|
|
30 |
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
31 |
zip_ref.extractall(".")
|
32 |
|
33 |
-
|
34 |
-
os.remove(zip_path)
|
35 |
-
return True
|
36 |
except Exception as e:
|
37 |
st.error(f"Failed to download index: {str(e)}")
|
38 |
return False
|
39 |
|
40 |
-
def get_available_providers(self) -> List[str]:
|
41 |
"""Get available providers without creating a SearchAgent instance"""
|
42 |
temp_tantivy = TantivySearchAgent(self.index_path)
|
43 |
-
temp_agent = SearchAgent(temp_tantivy)
|
44 |
return temp_agent.get_available_providers()
|
45 |
|
46 |
-
def initialize_system(self):
|
47 |
try:
|
48 |
# Check if index folder exists
|
49 |
if not os.path.exists(self.index_path):
|
@@ -54,10 +53,14 @@ class SearchAgentUI:
|
|
54 |
|
55 |
self.tantivy_agent = TantivySearchAgent(self.index_path)
|
56 |
if self.tantivy_agent.validate_index():
|
57 |
-
available_providers = self.get_available_providers()
|
|
|
|
|
|
|
58 |
self.agent = SearchAgent(
|
59 |
self.tantivy_agent,
|
60 |
-
provider_name=st.session_state.get('provider', available_providers[0])
|
|
|
61 |
)
|
62 |
return True, "ืืืขืจืืช ืืืื ื ืืืืคืืฉ", available_providers
|
63 |
else:
|
@@ -69,7 +72,7 @@ class SearchAgentUI:
|
|
69 |
st.set_page_config(
|
70 |
page_title="ืืืชืืจืื",
|
71 |
layout="wide",
|
72 |
-
initial_sidebar_state="
|
73 |
)
|
74 |
|
75 |
# Enhanced RTL support and styling
|
@@ -108,50 +111,109 @@ class SearchAgentUI:
|
|
108 |
margin: 5px 0;
|
109 |
background-color: white;
|
110 |
}
|
|
|
|
|
|
|
111 |
</style>
|
112 |
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
-
# Initialize system
|
115 |
-
success, status_msg, available_providers = self.initialize_system()
|
116 |
-
|
117 |
-
# Header layout
|
118 |
-
col1, col2, col3 = st.columns([2,1,1])
|
119 |
|
120 |
-
with col1:
|
121 |
-
if success:
|
122 |
-
st.success(status_msg)
|
123 |
-
else:
|
124 |
-
st.error(status_msg)
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
if available_providers:
|
|
|
|
|
|
|
131 |
provider = st.selectbox(
|
132 |
"ืกืคืง ืืื ื ืืืืืืชืืช",
|
133 |
options=available_providers,
|
134 |
-
key='provider'
|
|
|
135 |
)
|
136 |
if self.agent:
|
137 |
self.agent.set_provider(provider)
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
# Search input
|
157 |
query = st.text_input(
|
|
|
16 |
self.index_path ="./index" # os.getenv("INDEX_PATH", "./index")
|
17 |
# Google Drive folder ID for the index
|
18 |
self.gdrive_index_id = os.getenv("GDRIVE_INDEX_ID", "1lpbBCPimwcNfC0VZOlQueA4SHNGIp5_t")
|
19 |
+
|
20 |
def download_index_from_gdrive(self) -> bool:
|
21 |
"""Download index folder from Google Drive"""
|
22 |
try:
|
|
|
31 |
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
32 |
zip_ref.extractall(".")
|
33 |
|
34 |
+
|
|
|
|
|
35 |
except Exception as e:
|
36 |
st.error(f"Failed to download index: {str(e)}")
|
37 |
return False
|
38 |
|
39 |
+
def get_available_providers(self, api_keys: dict) -> List[str]:
|
40 |
"""Get available providers without creating a SearchAgent instance"""
|
41 |
temp_tantivy = TantivySearchAgent(self.index_path)
|
42 |
+
temp_agent = SearchAgent(temp_tantivy, api_keys=api_keys)
|
43 |
return temp_agent.get_available_providers()
|
44 |
|
45 |
+
def initialize_system(self, api_keys: dict):
|
46 |
try:
|
47 |
# Check if index folder exists
|
48 |
if not os.path.exists(self.index_path):
|
|
|
53 |
|
54 |
self.tantivy_agent = TantivySearchAgent(self.index_path)
|
55 |
if self.tantivy_agent.validate_index():
|
56 |
+
available_providers = self.get_available_providers(api_keys)
|
57 |
+
if not available_providers:
|
58 |
+
return False, "ืฉืืืื: ืื ื ืืฆืื ืกืคืงื AI ืืืื ืื. ืื ื ืืื ืืคืชื API ืืื ืืคืืืช.", []
|
59 |
+
|
60 |
self.agent = SearchAgent(
|
61 |
self.tantivy_agent,
|
62 |
+
provider_name=st.session_state.get('provider', available_providers[0]),
|
63 |
+
api_keys=api_keys
|
64 |
)
|
65 |
return True, "ืืืขืจืืช ืืืื ื ืืืืคืืฉ", available_providers
|
66 |
else:
|
|
|
72 |
st.set_page_config(
|
73 |
page_title="ืืืชืืจืื",
|
74 |
layout="wide",
|
75 |
+
initial_sidebar_state="expanded"
|
76 |
)
|
77 |
|
78 |
# Enhanced RTL support and styling
|
|
|
111 |
margin: 5px 0;
|
112 |
background-color: white;
|
113 |
}
|
114 |
+
[data-testid="stSidebar"] {
|
115 |
+
direction: rtl;
|
116 |
+
}
|
117 |
</style>
|
118 |
""", unsafe_allow_html=True)
|
119 |
+
|
120 |
+
st.session_state.api_keys = {
|
121 |
+
'google': "",
|
122 |
+
'openai': "",
|
123 |
+
'anthropic': ""
|
124 |
+
|
125 |
+
}
|
126 |
|
|
|
|
|
|
|
|
|
|
|
127 |
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
+
# Sidebar settings
|
130 |
+
with st.sidebar:
|
131 |
+
st.title("ืืืืจืืช")
|
132 |
+
|
133 |
+
# API Key Configuration
|
134 |
+
st.subheader("ืืืืจืช ืืคืชืืืช API")
|
135 |
+
|
136 |
+
# Google API Key
|
137 |
+
google_key = st.text_input(
|
138 |
+
"Google API Key",
|
139 |
+
value=st.session_state.api_keys['google'],
|
140 |
+
type="password",
|
141 |
+
key="google_key",
|
142 |
+
help="ืืื ืืช ืืคืชื ื-API ืฉื Google Gemini "
|
143 |
+
)
|
144 |
+
st.session_state.api_keys['google'] = google_key
|
145 |
+
|
146 |
+
st.html('<small> ื ืืชื ืืืฉืื ืืคืชื <a href="https://aistudio.google.com/app/apikey">ืืื</a> </small>', )
|
147 |
+
|
148 |
+
# OpenAI API Key
|
149 |
+
openai_key = st.text_input(
|
150 |
+
"OpenAI API Key",
|
151 |
+
value=st.session_state.api_keys['openai'],
|
152 |
+
type="password",
|
153 |
+
key="openai_key",
|
154 |
+
help="ืืื ืืช ืืคืชื ื-API ืฉื OpenAI"
|
155 |
+
)
|
156 |
+
st.session_state.api_keys['openai'] = openai_key
|
157 |
|
158 |
+
|
159 |
+
st.html('<small> ื ืืชื ืืืฉืื ืืคืชื <a href="https://platform.openai.com/account/api-keys">ืืื</a> </small> ', )
|
160 |
+
|
161 |
+
|
162 |
+
# Anthropic API Key
|
163 |
+
anthropic_key = st.text_input(
|
164 |
+
"Anthropic API Key",
|
165 |
+
value=st.session_state.api_keys['anthropic'],
|
166 |
+
type="password",
|
167 |
+
key="anthropic_key",
|
168 |
+
help="ืืื ืืช ืืคืชื ื-API ืฉื Anthropic Claude"
|
169 |
+
)
|
170 |
+
st.session_state.api_keys['anthropic'] = anthropic_key
|
171 |
+
|
172 |
+
|
173 |
+
st.html('<small> ื ืืชื ืืืฉืื ืืคืชื <a href="https://console.anthropic.com/">ืืื</a> </small>', )
|
174 |
+
|
175 |
+
|
176 |
+
st.markdown("---")
|
177 |
+
|
178 |
+
# Initialize system with current API keys
|
179 |
+
success, status_msg, available_providers = self.initialize_system(st.session_state.api_keys)
|
180 |
+
|
181 |
+
# Continue with sidebar settings
|
182 |
+
with st.sidebar:
|
183 |
if available_providers:
|
184 |
+
if 'provider' not in st.session_state or st.session_state.provider not in available_providers:
|
185 |
+
st.session_state.provider = available_providers[0] if available_providers else None
|
186 |
+
|
187 |
provider = st.selectbox(
|
188 |
"ืกืคืง ืืื ื ืืืืืืชืืช",
|
189 |
options=available_providers,
|
190 |
+
key='provider',
|
191 |
+
help="ืืืจ ืืช ืืืื ืAI ืืฉืืืืฉ (ืจืง ืืืืืื ืขื ืืคืชื API ืืืื ืืืฆืื)"
|
192 |
)
|
193 |
if self.agent:
|
194 |
self.agent.set_provider(provider)
|
195 |
|
196 |
+
max_iterations = st.number_input(
|
197 |
+
"ืืกืคืจ ื ืกืืื ืืช ืืงืกืืืื",
|
198 |
+
min_value=1,
|
199 |
+
value=6,
|
200 |
+
key='max_iterations'
|
201 |
+
)
|
202 |
+
|
203 |
+
results_per_search = st.number_input(
|
204 |
+
"ืชืืฆืืืช ืืื ืืืคืืฉ",
|
205 |
+
min_value=1,
|
206 |
+
value=10,
|
207 |
+
key='results_per_search'
|
208 |
+
)
|
209 |
+
|
210 |
+
# Main content area
|
211 |
+
st.title("ืืืชืืจืื")
|
212 |
+
|
213 |
+
if success:
|
214 |
+
st.success(status_msg)
|
215 |
+
else:
|
216 |
+
st.error(status_msg)
|
217 |
|
218 |
# Search input
|
219 |
query = st.text_input(
|
llm_providers.py
CHANGED
@@ -90,35 +90,32 @@ class GeminiProvider:
|
|
90 |
raise ValueError("Unsupported prompt format")
|
91 |
|
92 |
class LLMProvider:
|
93 |
-
def __init__(self):
|
94 |
self.providers: Dict[str, Any] = {}
|
95 |
-
self._setup_providers()
|
96 |
|
97 |
-
def _setup_providers(self):
|
98 |
-
|
99 |
-
|
100 |
-
if google_key
|
101 |
self.providers['Gemini'] = GeminiProvider(api_key=google_key)
|
102 |
|
103 |
-
|
104 |
# Anthropic
|
105 |
-
|
|
|
106 |
self.providers['Claude'] = ChatAnthropic(
|
107 |
api_key=anthropic_key,
|
108 |
model_name="claude-3-5-sonnet-20241022",
|
109 |
-
|
110 |
)
|
111 |
|
112 |
# OpenAI
|
113 |
-
|
|
|
114 |
self.providers['ChatGPT'] = ChatOpenAI(
|
115 |
api_key=openai_key,
|
116 |
model_name="gpt-4o-2024-11-20"
|
117 |
)
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
def get_available_providers(self) -> list[str]:
|
123 |
"""Return list of available provider names"""
|
124 |
return list(self.providers.keys())
|
|
|
90 |
raise ValueError("Unsupported prompt format")
|
91 |
|
92 |
class LLMProvider:
|
93 |
+
def __init__(self, api_keys: Dict[str, str] = None):
|
94 |
self.providers: Dict[str, Any] = {}
|
95 |
+
self._setup_providers(api_keys or {})
|
96 |
|
97 |
+
def _setup_providers(self, api_keys: Dict[str, str]):
|
98 |
+
# Google Gemini
|
99 |
+
google_key = api_keys.get('google') or os.getenv('GOOGLE_API_KEY')
|
100 |
+
if google_key:
|
101 |
self.providers['Gemini'] = GeminiProvider(api_key=google_key)
|
102 |
|
|
|
103 |
# Anthropic
|
104 |
+
anthropic_key = api_keys.get('anthropic') or os.getenv('ANTHROPIC_API_KEY')
|
105 |
+
if anthropic_key:
|
106 |
self.providers['Claude'] = ChatAnthropic(
|
107 |
api_key=anthropic_key,
|
108 |
model_name="claude-3-5-sonnet-20241022",
|
|
|
109 |
)
|
110 |
|
111 |
# OpenAI
|
112 |
+
openai_key = api_keys.get('openai') or os.getenv('OPENAI_API_KEY')
|
113 |
+
if openai_key:
|
114 |
self.providers['ChatGPT'] = ChatOpenAI(
|
115 |
api_key=openai_key,
|
116 |
model_name="gpt-4o-2024-11-20"
|
117 |
)
|
118 |
|
|
|
|
|
|
|
119 |
def get_available_providers(self) -> list[str]:
|
120 |
"""Return list of available provider names"""
|
121 |
return list(self.providers.keys())
|