Sivan Ratson commited on
Commit
d61d72d
1 Parent(s): 8d38779

use API provided by the user.

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. agent_workflow.py +4 -4
  3. app.py +101 -39
  4. llm_providers.py +10 -13
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ index
3
+ __pycache__
agent_workflow.py CHANGED
@@ -9,17 +9,17 @@ from tantivy_search_agent import TantivySearchAgent
9
  load_dotenv()
10
 
11
  class SearchAgent:
12
- def __init__(self, tantivy_agent: TantivySearchAgent, provider_name: str = "Gemini"):
13
  """Initialize the search agent with Tantivy agent and LLM client"""
14
  self.tantivy_agent = tantivy_agent
15
  self.logger = logging.getLogger(__name__)
16
 
17
- # Initialize LLM provider
18
- self.llm_provider = LLMProvider()
19
  self.llm = None
20
  self.set_provider(provider_name)
21
 
22
- self.min_confidence_threshold = 0.5
23
 
24
  def set_provider(self, provider_name: str) -> None:
25
  self.llm = self.llm_provider.get_provider(provider_name)
 
9
  load_dotenv()
10
 
11
  class SearchAgent:
12
+ def __init__(self, tantivy_agent: TantivySearchAgent, provider_name: str = "Gemini", api_keys: Dict[str, str] = None):
13
  """Initialize the search agent with Tantivy agent and LLM client"""
14
  self.tantivy_agent = tantivy_agent
15
  self.logger = logging.getLogger(__name__)
16
 
17
+ # Initialize LLM provider with API keys
18
+ self.llm_provider = LLMProvider(api_keys)
19
  self.llm = None
20
  self.set_provider(provider_name)
21
 
22
+ self.min_confidence_threshold = 0.7
23
 
24
  def set_provider(self, provider_name: str) -> None:
25
  self.llm = self.llm_provider.get_provider(provider_name)
app.py CHANGED
@@ -16,6 +16,7 @@ class SearchAgentUI:
16
  self.index_path ="./index" # os.getenv("INDEX_PATH", "./index")
17
  # Google Drive folder ID for the index
18
  self.gdrive_index_id = os.getenv("GDRIVE_INDEX_ID", "1lpbBCPimwcNfC0VZOlQueA4SHNGIp5_t")
 
19
  def download_index_from_gdrive(self) -> bool:
20
  """Download index folder from Google Drive"""
21
  try:
@@ -30,20 +31,18 @@ class SearchAgentUI:
30
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
31
  zip_ref.extractall(".")
32
 
33
- # Remove the zip file
34
- os.remove(zip_path)
35
- return True
36
  except Exception as e:
37
  st.error(f"Failed to download index: {str(e)}")
38
  return False
39
 
40
- def get_available_providers(self) -> List[str]:
41
  """Get available providers without creating a SearchAgent instance"""
42
  temp_tantivy = TantivySearchAgent(self.index_path)
43
- temp_agent = SearchAgent(temp_tantivy)
44
  return temp_agent.get_available_providers()
45
 
46
- def initialize_system(self):
47
  try:
48
  # Check if index folder exists
49
  if not os.path.exists(self.index_path):
@@ -54,10 +53,14 @@ class SearchAgentUI:
54
 
55
  self.tantivy_agent = TantivySearchAgent(self.index_path)
56
  if self.tantivy_agent.validate_index():
57
- available_providers = self.get_available_providers()
 
 
 
58
  self.agent = SearchAgent(
59
  self.tantivy_agent,
60
- provider_name=st.session_state.get('provider', available_providers[0])
 
61
  )
62
  return True, "讛诪注专讻转 诪讜讻谞讛 诇讞讬驻讜砖", available_providers
63
  else:
@@ -69,7 +72,7 @@ class SearchAgentUI:
69
  st.set_page_config(
70
  page_title="讗讬转讜专讬讗",
71
  layout="wide",
72
- initial_sidebar_state="collapsed"
73
  )
74
 
75
  # Enhanced RTL support and styling
@@ -108,50 +111,109 @@ class SearchAgentUI:
108
  margin: 5px 0;
109
  background-color: white;
110
  }
 
 
 
111
  </style>
112
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
113
 
114
- # Initialize system
115
- success, status_msg, available_providers = self.initialize_system()
116
-
117
- # Header layout
118
- col1, col2, col3 = st.columns([2,1,1])
119
 
120
- with col1:
121
- if success:
122
- st.success(status_msg)
123
- else:
124
- st.error(status_msg)
125
 
126
- with col2:
127
- if 'provider' not in st.session_state:
128
- st.session_state.provider = available_providers[0] if available_providers else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if available_providers:
 
 
 
131
  provider = st.selectbox(
132
  "住驻拽 讘讬谞讛 诪诇讗讻讜转讬转",
133
  options=available_providers,
134
- key='provider'
 
135
  )
136
  if self.agent:
137
  self.agent.set_provider(provider)
138
 
139
- with col3:
140
- col3_1, col3_2 = st.columns(2)
141
- with col3_1:
142
- max_iterations = st.number_input(
143
- "诪住驻专 谞住讬讜谞讜转 诪拽住讬诪诇讬",
144
- min_value=1,
145
- value=3,
146
- key='max_iterations'
147
- )
148
- with col3_2:
149
- results_per_search = st.number_input(
150
- "转讜爪讗讜转 诇讻诇 讞讬驻讜砖",
151
- min_value=1,
152
- value=5,
153
- key='results_per_search'
154
- )
 
 
 
 
 
155
 
156
  # Search input
157
  query = st.text_input(
 
16
  self.index_path ="./index" # os.getenv("INDEX_PATH", "./index")
17
  # Google Drive folder ID for the index
18
  self.gdrive_index_id = os.getenv("GDRIVE_INDEX_ID", "1lpbBCPimwcNfC0VZOlQueA4SHNGIp5_t")
19
+
20
  def download_index_from_gdrive(self) -> bool:
21
  """Download index folder from Google Drive"""
22
  try:
 
31
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
32
  zip_ref.extractall(".")
33
 
34
+
 
 
35
  except Exception as e:
36
  st.error(f"Failed to download index: {str(e)}")
37
  return False
38
 
39
+ def get_available_providers(self, api_keys: dict) -> List[str]:
40
  """Get available providers without creating a SearchAgent instance"""
41
  temp_tantivy = TantivySearchAgent(self.index_path)
42
+ temp_agent = SearchAgent(temp_tantivy, api_keys=api_keys)
43
  return temp_agent.get_available_providers()
44
 
45
+ def initialize_system(self, api_keys: dict):
46
  try:
47
  # Check if index folder exists
48
  if not os.path.exists(self.index_path):
 
53
 
54
  self.tantivy_agent = TantivySearchAgent(self.index_path)
55
  if self.tantivy_agent.validate_index():
56
+ available_providers = self.get_available_providers(api_keys)
57
+ if not available_providers:
58
+ return False, "砖讙讬讗讛: 诇讗 谞诪爪讗讜 住驻拽讬 AI 讝诪讬谞讬诐. 讗谞讗 讛讝谉 诪驻转讞 API 讗讞讚 诇驻讞讜转.", []
59
+
60
  self.agent = SearchAgent(
61
  self.tantivy_agent,
62
+ provider_name=st.session_state.get('provider', available_providers[0]),
63
+ api_keys=api_keys
64
  )
65
  return True, "讛诪注专讻转 诪讜讻谞讛 诇讞讬驻讜砖", available_providers
66
  else:
 
72
  st.set_page_config(
73
  page_title="讗讬转讜专讬讗",
74
  layout="wide",
75
+ initial_sidebar_state="expanded"
76
  )
77
 
78
  # Enhanced RTL support and styling
 
111
  margin: 5px 0;
112
  background-color: white;
113
  }
114
+ [data-testid="stSidebar"] {
115
+ direction: rtl;
116
+ }
117
  </style>
118
  """, unsafe_allow_html=True)
119
+
120
+ st.session_state.api_keys = {
121
+ 'google': "",
122
+ 'openai': "",
123
+ 'anthropic': ""
124
+
125
+ }
126
 
 
 
 
 
 
127
 
 
 
 
 
 
128
 
129
+ # Sidebar settings
130
+ with st.sidebar:
131
+ st.title("讛讙讚专讜转")
132
+
133
+ # API Key Configuration
134
+ st.subheader("讛讙讚专转 诪驻转讞讜转 API")
135
+
136
+ # Google API Key
137
+ google_key = st.text_input(
138
+ "Google API Key",
139
+ value=st.session_state.api_keys['google'],
140
+ type="password",
141
+ key="google_key",
142
+ help="讛讝谉 讗转 诪驻转讞 讛-API 砖诇 Google Gemini "
143
+ )
144
+ st.session_state.api_keys['google'] = google_key
145
+
146
+ st.html('<small> 谞讬转谉 诇讛砖讬讙 诪驻转讞 <a href="https://aistudio.google.com/app/apikey">讻讗谉</a> </small>', )
147
+
148
+ # OpenAI API Key
149
+ openai_key = st.text_input(
150
+ "OpenAI API Key",
151
+ value=st.session_state.api_keys['openai'],
152
+ type="password",
153
+ key="openai_key",
154
+ help="讛讝谉 讗转 诪驻转讞 讛-API 砖诇 OpenAI"
155
+ )
156
+ st.session_state.api_keys['openai'] = openai_key
157
 
158
+
159
+ st.html('<small> 谞讬转谉 诇讛砖讬讙 诪驻转讞 <a href="https://platform.openai.com/account/api-keys">讻讗谉</a> </small> ', )
160
+
161
+
162
+ # Anthropic API Key
163
+ anthropic_key = st.text_input(
164
+ "Anthropic API Key",
165
+ value=st.session_state.api_keys['anthropic'],
166
+ type="password",
167
+ key="anthropic_key",
168
+ help="讛讝谉 讗转 诪驻转讞 讛-API 砖诇 Anthropic Claude"
169
+ )
170
+ st.session_state.api_keys['anthropic'] = anthropic_key
171
+
172
+
173
+ st.html('<small> 谞讬转谉 诇讛砖讬讙 诪驻转讞 <a href="https://console.anthropic.com/">讻讗谉</a> </small>', )
174
+
175
+
176
+ st.markdown("---")
177
+
178
+ # Initialize system with current API keys
179
+ success, status_msg, available_providers = self.initialize_system(st.session_state.api_keys)
180
+
181
+ # Continue with sidebar settings
182
+ with st.sidebar:
183
  if available_providers:
184
+ if 'provider' not in st.session_state or st.session_state.provider not in available_providers:
185
+ st.session_state.provider = available_providers[0] if available_providers else None
186
+
187
  provider = st.selectbox(
188
  "住驻拽 讘讬谞讛 诪诇讗讻讜转讬转",
189
  options=available_providers,
190
+ key='provider',
191
+ help="讘讞专 讗转 诪讜讚诇 讛AI 诇砖讬诪讜砖 (专拽 诪讜讚诇讬诐 注诐 诪驻转讞 API 讝诪讬谉 讬讜爪讙讜)"
192
  )
193
  if self.agent:
194
  self.agent.set_provider(provider)
195
 
196
+ max_iterations = st.number_input(
197
+ "诪住驻专 谞住讬讜谞讜转 诪拽住讬诪诇讬",
198
+ min_value=1,
199
+ value=6,
200
+ key='max_iterations'
201
+ )
202
+
203
+ results_per_search = st.number_input(
204
+ "转讜爪讗讜转 诇讻诇 讞讬驻讜砖",
205
+ min_value=1,
206
+ value=10,
207
+ key='results_per_search'
208
+ )
209
+
210
+ # Main content area
211
+ st.title("讗讬转讜专讬讗")
212
+
213
+ if success:
214
+ st.success(status_msg)
215
+ else:
216
+ st.error(status_msg)
217
 
218
  # Search input
219
  query = st.text_input(
llm_providers.py CHANGED
@@ -90,35 +90,32 @@ class GeminiProvider:
90
  raise ValueError("Unsupported prompt format")
91
 
92
  class LLMProvider:
93
- def __init__(self):
94
  self.providers: Dict[str, Any] = {}
95
- self._setup_providers()
96
 
97
- def _setup_providers(self):
98
-
99
- # Google Gemini
100
- if google_key := os.getenv('GOOGLE_API_KEY'):
101
  self.providers['Gemini'] = GeminiProvider(api_key=google_key)
102
 
103
-
104
  # Anthropic
105
- if anthropic_key := os.getenv('ANTHROPIC_API_KEY'):
 
106
  self.providers['Claude'] = ChatAnthropic(
107
  api_key=anthropic_key,
108
  model_name="claude-3-5-sonnet-20241022",
109
-
110
  )
111
 
112
  # OpenAI
113
- if openai_key := os.getenv('OPENAI_API_KEY'):
 
114
  self.providers['ChatGPT'] = ChatOpenAI(
115
  api_key=openai_key,
116
  model_name="gpt-4o-2024-11-20"
117
  )
118
 
119
-
120
-
121
-
122
  def get_available_providers(self) -> list[str]:
123
  """Return list of available provider names"""
124
  return list(self.providers.keys())
 
90
  raise ValueError("Unsupported prompt format")
91
 
92
  class LLMProvider:
93
+ def __init__(self, api_keys: Dict[str, str] = None):
94
  self.providers: Dict[str, Any] = {}
95
+ self._setup_providers(api_keys or {})
96
 
97
+ def _setup_providers(self, api_keys: Dict[str, str]):
98
+ # Google Gemini
99
+ google_key = api_keys.get('google') or os.getenv('GOOGLE_API_KEY')
100
+ if google_key:
101
  self.providers['Gemini'] = GeminiProvider(api_key=google_key)
102
 
 
103
  # Anthropic
104
+ anthropic_key = api_keys.get('anthropic') or os.getenv('ANTHROPIC_API_KEY')
105
+ if anthropic_key:
106
  self.providers['Claude'] = ChatAnthropic(
107
  api_key=anthropic_key,
108
  model_name="claude-3-5-sonnet-20241022",
 
109
  )
110
 
111
  # OpenAI
112
+ openai_key = api_keys.get('openai') or os.getenv('OPENAI_API_KEY')
113
+ if openai_key:
114
  self.providers['ChatGPT'] = ChatOpenAI(
115
  api_key=openai_key,
116
  model_name="gpt-4o-2024-11-20"
117
  )
118
 
 
 
 
119
  def get_available_providers(self) -> list[str]:
120
  """Return list of available provider names"""
121
  return list(self.providers.keys())