Spaces:
Running
Running
XinGuan2000
commited on
Commit
·
29a316b
1
Parent(s):
40ea049
init
Browse files- app.py +479 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard library
import http.client
import json
import logging
import sys
import time  # Use time.sleep to simulate processing steps
from io import StringIO
from pathlib import Path
from urllib.parse import urlparse

# Third-party
import ollama
import streamlit as st
from openai import AzureOpenAI
from tqdm import tqdm

# Local
from saged import Pipeline
from saged import SAGEDData as dt
|
14 |
+
|
15 |
+
# Custom logging handler that accumulates formatted log records in memory
# so they can later be rendered inside the Streamlit UI.
class StreamlitLogHandler(logging.Handler):
    """Logging handler that buffers every formatted record in a StringIO."""

    def __init__(self):
        super().__init__()
        # In-memory buffer holding one formatted log line per record.
        self.log_capture_string = StringIO()

    def emit(self, record):
        """Format *record* and append it (newline-terminated) to the buffer."""
        self.log_capture_string.write(self.format(record) + "\n")

    def get_logs(self):
        """Return everything logged so far as one string."""
        return self.log_capture_string.getvalue()

    def clear_logs(self):
        """Empty the buffer in place without replacing the StringIO object."""
        self.log_capture_string.truncate(0)
        self.log_capture_string.seek(0)
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
# Define ContentFormatter class
class ContentFormatter:
    """Builds the JSON request body for a chat-completions style call."""

    @staticmethod
    def chat_completions(text, settings_params):
        """Return a JSON string holding a system+user message pair plus any
        extra generation settings merged in at the top level."""
        payload = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": text},
            ],
            **settings_params,
        }
        return json.dumps(payload)
|
47 |
+
|
48 |
+
|
49 |
+
# Define OllamaModel (For local Ollama interaction)
class OllamaModel:
    """Wraps a locally served Ollama model derived from a base model plus a
    system prompt (and optional PARAMETER overrides)."""

    def __init__(self, base_model='llama3', system_prompt='You are a helpful assistant', model_name='llama3o',
                 **kwargs):
        self.base_model = base_model
        self.model_name = model_name
        # Register the derived model with the Ollama daemon immediately.
        self.model_create(model_name, system_prompt, base_model, **kwargs)

    def model_create(self, model_name, system_prompt, base_model, **kwargs):
        """Create the model on the Ollama server from a generated Modelfile.

        Extra keyword arguments become `PARAMETER <name> <value>` lines
        (names lower-cased, as Modelfile syntax expects)."""
        lines = [f'FROM {base_model}', f'SYSTEM {system_prompt}']
        lines.extend(f'PARAMETER {key.lower()} {value}' for key, value in kwargs.items())
        ollama.create(model=model_name, modelfile='\n'.join(lines) + '\n')

    def invoke(self, prompt):
        """Generate a completion for *prompt* and return the response text."""
        result = ollama.generate(model=self.model_name, prompt=prompt)
        return result['response']
|
67 |
+
|
68 |
+
|
69 |
+
# Define GPTAgent (For OpenAI GPT models)
class GPTAgent:
    """Thin wrapper around an Azure-hosted OpenAI chat deployment.

    The `model_name` constructor argument is kept for interface compatibility
    even though the Azure client addresses the model via `deployment_name`.
    """

    def __init__(self, model_name, azure_key, azure_version, azure_endpoint, deployment_name):
        self.client = AzureOpenAI(
            api_key=azure_key,
            api_version=azure_version,
            azure_endpoint=azure_endpoint
        )
        self.deployment_name = deployment_name

    def invoke(self, prompt, settings_params=None):
        """Send *prompt* as a user message and return the assistant reply text.

        Args:
            prompt: user message content.
            settings_params: optional dict of extra generation settings
                (e.g. temperature) forwarded to the completions call.
        """
        if not settings_params:
            settings_params = {}
        # Build the message list directly instead of serialising it to JSON
        # via ContentFormatter and immediately parsing it back (the original
        # round-trip added no behaviour, only overhead).
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        response = self.client.chat.completions.create(
            model=self.deployment_name,
            messages=messages,
            **settings_params
        )
        return response.choices[0].message.content
|
89 |
+
|
90 |
+
|
91 |
+
# Define AzureAgent (For Azure OpenAI models)
class AzureAgent:
    """Calls an Azure-hosted chat-completions endpoint over raw HTTPS."""

    def __init__(self, model_name, azure_uri, azure_api_key):
        # `azure_uri` may be a bare host or a full URL; it is normalised in invoke().
        self.azure_uri = azure_uri
        self.headers = {
            'Authorization': f"Bearer {azure_api_key}",
            'Content-Type': 'application/json'
        }
        self.chat_formatter = ContentFormatter

    def invoke(self, prompt, settings_params=None):
        """POST *prompt* to /v1/chat/completions and return the reply text.

        Args:
            prompt: user message content.
            settings_params: optional dict of extra generation settings.
        """
        if not settings_params:
            settings_params = {}
        body = self.chat_formatter.chat_completions(prompt, {**settings_params})
        # BUG FIX: the sidebar collects a full URL (e.g. "https://host/...") but
        # HTTPSConnection expects a bare host name — extract it when a scheme
        # is present, otherwise use the value as-is.
        host = urlparse(self.azure_uri).netloc or self.azure_uri
        conn = http.client.HTTPSConnection(host)
        try:
            conn.request("POST", '/v1/chat/completions', body=body, headers=self.headers)
            response = conn.getresponse()
            data = response.read()
        finally:
            conn.close()  # release the socket even if the request raises
        parsed_data = json.loads(data.decode("utf-8"))
        return parsed_data["choices"][0]["message"]["content"]
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
# Renew Source Finder Button
def renew_source_finder(domain, concept_list):
    """Delete the cached source-finder JSON file for every concept of *domain*
    so the next run rebuilds them from scratch."""
    # Forget which synthetic files were generated so they can be re-created.
    if 'generated_synthetic_files' in st.session_state:
        del st.session_state['generated_synthetic_files']
    if not domain or not concept_list:
        st.error("Please fill in all the required fields before proceeding.")
        return
    with st.spinner("Renewing source info files..."):
        base_path = Path('data/customized/source_finder/')
        for concept in concept_list:
            target = base_path / f'{domain}_{concept}_source_finder.json'
            if not target.exists():
                continue
            try:
                target.unlink()  # Delete the file
                st.info(f"Deleted source info file: {target}")
            except Exception as e:
                st.error(f"An error occurred while deleting the file {target}: {e}")
        st.success("Source info files renewal completed!")
|
135 |
+
|
136 |
+
|
137 |
+
def create_source_finder(domain, concept):
    """Build source-finder data pointing at the local text file for *concept*.

    Warns (but still proceeds) when the expected local file is missing, so a
    placeholder entry can be created ahead of the file itself.
    """
    local_file = f"data/customized/local_files/{domain}/{concept}.txt"
    if not Path(local_file).exists():
        st.warning(f"Local file does not exist: {local_file}")
    instance = dt.create_data(domain, concept, 'source_finder')
    instance.data[0]['keywords'] = {concept: dt.default_keyword_metadata.copy()}
    source_item = dt.default_source_item.copy()
    source_item['source_type'] = "local_paths"
    source_item['source_specification'] = [local_file]
    instance.data[0]['category_shared_source'] = [source_item]
    return instance.data.copy()
|
148 |
+
|
149 |
+
|
150 |
+
def check_and_create_source_files(domain, concept_list):
    """
    Checks if the required source finder files exist for each concept in the domain.
    If a file does not exist or is invalid, it creates a fresh source-finder
    JSON file for that concept.
    """
    base_path = Path('data/customized/source_finder/')
    base_path.mkdir(parents=True, exist_ok=True)

    def _write_source_file(file_path, concept, notice):
        # Shared write-and-report path (was duplicated for the missing-file
        # and invalid-file cases in the original implementation).
        data = create_source_finder(domain, concept)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        st.info(notice)

    for concept in concept_list:
        file_path = base_path / f'{domain}_{concept}_source_finder.json'
        if not file_path.exists():
            _write_source_file(file_path, concept,
                               f"Created missing source finder file: {file_path}")
        elif dt.load_file(domain, concept, 'source_finder', file_path) is None:
            # Loading failed, so the existing file is invalid — replace it.
            _write_source_file(file_path, concept,
                               f"Recreated invalid source finder file: {file_path}")
|
174 |
+
|
175 |
+
|
176 |
+
def clean_spaces(data):
    """Strip leading/trailing whitespace from a string, or from every string
    element of a list (non-string elements pass through untouched).

    Raises:
        TypeError: if *data* is neither a string nor a list.
    """
    if isinstance(data, str):
        return data.strip()
    if isinstance(data, list):
        return [element.strip() if isinstance(element, str) else element
                for element in data]
    raise TypeError("Input should be either a string or a list of strings")
|
186 |
+
|
187 |
+
|
188 |
+
def create_replacement_dict(concept_list, replacer):
    """Map each concept to a per-company replacement table:
    {concept: {company: {concept: company}}} for every company in *replacer*."""
    return {
        concept: {company: {concept: company} for company in replacer}
        for concept in concept_list
    }
|
195 |
+
|
196 |
+
|
197 |
+
# Title of the app
st.title("SAGED-bias Benchmark-Building Demo")

# Initialize session state variables (None until the user configures them).
for _state_key in ('domain', 'concept_list', 'gpt_model', 'azure_model', 'ollama_model'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
211 |
+
|
212 |
+
# Sidebar: Model Selection
# Lets the user pick one backend (Ollama local, GPT via Azure SDK, or raw
# Azure endpoint) and stores the configured agent in session state.
with st.sidebar:
    st.header("Model Configuration")

    # Selection of which model to use
    model_selection = st.radio("Select Model Type", ['GPT-Azure', 'Azure', 'Ollama'])

    # Collapsible Additional Configuration Section
    with st.expander("Model Configuration"):
        if model_selection == 'Ollama':
            # Ollama Configuration
            ollama_deployment_name = st.text_input("Enter Ollama Model Deployment Name", placeholder="e.g., llama3")
            ollama_system_prompt = st.text_input("Enter System Prompt for Ollama",
                                                 placeholder="e.g., You are a helpful assistant.")

            # Only offer the confirm button once both fields are filled in.
            if ollama_deployment_name and ollama_system_prompt:
                confirm_ollama = st.button("Confirm Ollama Configuration")
                if confirm_ollama:
                    # Constructing OllamaModel registers the model with the
                    # local Ollama daemon (see OllamaModel.__init__).
                    st.session_state['ollama_model'] = OllamaModel(
                        model_name=ollama_deployment_name,
                        system_prompt=ollama_system_prompt
                    )
                    st.success("Ollama model configured successfully.")
            else:
                st.warning("Please provide both Ollama deployment name and system prompt.")

        elif model_selection == 'GPT-Azure' or model_selection == 'Azure':
            # GPT / Azure Configuration — both variants share the same inputs.
            gpt_azure_endpoint = st.text_input("Enter Azure Endpoint URL",
                                               placeholder="e.g., https://your-resource-name.openai.azure.com/")
            gpt_azure_api_key = st.text_input("Enter Azure API Key", type="password")
            gpt_azure_model_name = st.text_input("Enter Azure Model Name", placeholder="e.g., GPT-3.5-turbo")
            gpt_azure_deployment_name = st.text_input("Enter Azure Deployment Name",
                                                      placeholder="e.g., gpt-3-5-deployment")

            if gpt_azure_endpoint and gpt_azure_api_key and gpt_azure_model_name and gpt_azure_deployment_name:
                confirm_gpt_azure = st.button("Confirm GPT/Azure Configuration")
                if confirm_gpt_azure:
                    if model_selection == 'GPT-Azure':
                        st.session_state['gpt_model'] = GPTAgent(
                            model_name=gpt_azure_model_name,
                            azure_key=gpt_azure_api_key,
                            azure_version='2023-05-15',  # Update if necessary
                            azure_endpoint=gpt_azure_endpoint,
                            deployment_name=gpt_azure_deployment_name
                        )
                        st.success("GPT model configured successfully.")
                    elif model_selection == 'Azure':
                        st.session_state['azure_model'] = AzureAgent(
                            model_name=gpt_azure_model_name,
                            azure_uri=gpt_azure_endpoint,
                            azure_api_key=gpt_azure_api_key
                        )
                        st.success("Azure model configured successfully.")
            else:
                st.warning("Please provide all fields for GPT/Azure configuration.")
|
268 |
+
|
269 |
+
# Main interaction based on configured model
# Pick whichever agent was configured, preferring Ollama, then GPT, then Azure;
# `model` stays None until the sidebar has been used.
if st.session_state.get('ollama_model'):
    model = st.session_state['ollama_model']
elif st.session_state.get('gpt_model'):
    model = st.session_state['gpt_model']
elif st.session_state.get('azure_model'):
    model = st.session_state['azure_model']
else:
    model = None

# User input: Domain and Concepts
with st.form(key='domain_concept_form'):
    domain = clean_spaces(
        st.text_input("Enter the domain: (e.g., Stocks, Education)", placeholder="Enter domain here..."))

    # User input: Concepts
    concept_text = st.text_area("Enter the concepts (separated by commas):",
                                placeholder="e.g., excel-stock, ok-stock, bad-stock")
    # NOTE: splitting an empty string yields [''], so the emptiness check
    # below also tests concept_text.strip() rather than the list alone.
    concept_list = clean_spaces(concept_text.split(','))

    submit_button = st.form_submit_button(label='Confirm Domain and Concepts')

    if submit_button:
        if not domain:
            st.warning("Please enter a domain.")
        elif not concept_list or concept_text.strip() == "":
            st.warning("Please enter at least one concept.")
        else:
            # Persist the confirmed values so the options section unlocks.
            st.session_state['domain'] = domain
            st.session_state['concept_list'] = concept_list
            st.success("Domain and concepts confirmed.")
|
300 |
+
|
301 |
+
# Display further options only after domain and concepts are confirmed
# NOTE(review): nesting reconstructed from a whitespace-stripped source —
# confirm block boundaries against the original file.
if st.session_state['domain'] and st.session_state['concept_list']:
    with st.expander("Additional Options"):
        # User input: Method
        scraper_method = st.radio("Select the scraper method:", (('wiki', 'local_files', 'synthetic_files')))

        # Initiate the source_finder_requirement and keyword_finder_requirement if 'wiki' is selected
        if scraper_method == 'wiki':
            st.session_state['keyword_finder_requirement'] = True
            st.session_state['source_finder_requirement'] = True
            st.session_state['check_source_finder'] = False

        # File upload for each concept if 'local_files' is selected
        if scraper_method == 'local_files':
            uploaded_files = {}
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            for concept in st.session_state['concept_list']:
                uploaded_file = st.file_uploader(f"Upload file for concept '{concept}':", type=['txt'],
                                                 key=f"file_{concept}")
                if uploaded_file:
                    uploaded_files[concept] = uploaded_file
                    # Save uploaded file where create_source_finder expects it.
                    save_path = Path(f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(save_path, 'wb') as f:
                        f.write(uploaded_file.getbuffer())
                    st.success(f"File for concept '{concept}' saved successfully.")

        # Generate synthetic files if 'synthetic_files' is selected
        if scraper_method == 'synthetic_files':
            # Downstream config treats generated text exactly like local files.
            scraper_method = 'local_files'
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            if 'generated_synthetic_files' not in st.session_state:
                st.session_state['generated_synthetic_files'] = set()

            # One prompt input per concept that has not been generated yet.
            prompt_inputs = {}
            for concept in st.session_state['concept_list']:
                if concept not in st.session_state['generated_synthetic_files']:
                    prompt_inputs[concept] = st.text_input(
                        f"Enter the prompt for concept '{concept}':",
                        value=f"Write a long article introducing the {concept} in the {st.session_state['domain']}. Use the {concept} as much as possible.",
                        key=f"prompt_{concept}"
                    )

            if st.button("Generate Synthetic Files for All Concepts"):
                if model:
                    for concept, prompt in prompt_inputs.items():
                        if prompt:
                            with st.spinner(f"Generating content for concept '{concept}'..."):
                                synthetic_content = model.invoke(prompt)
                                save_path = Path(
                                    f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                                save_path.parent.mkdir(parents=True, exist_ok=True)
                                with open(save_path, 'w', encoding='utf-8') as f:
                                    f.write(synthetic_content)
                                # Remember so this concept's prompt box is hidden next rerun.
                                st.session_state['generated_synthetic_files'].add(concept)
                                st.success(f"Synthetic file for concept '{concept}' created successfully.")
                else:
                    st.warning("Please configure a model to generate synthetic files.")

        # User input: Prompt Method
        prompt_method = st.radio("Select the prompt method:", ('split_sentences', 'questions'), index = 0)

        # User input: Max Benchmark Length
        max_benchmark_length = st.slider("Select the maximum prompts per concepts:", 1, 199, 10)

        # User input: Branching
        branching = st.radio("Enable branching:", ('Yes', 'No'), index=1)
        branching_enabled = True if branching == 'Yes' else False

        # User input: Replacer (only if branching is enabled)
        replacer = []
        replacement = {}
        if branching_enabled:
            replacer_text = st.text_area("Enter the replacer list (list of strings, separated by commas):",
                                         placeholder="e.g., Company A, Company B")
            replacer = clean_spaces(replacer_text.split(','))
            replacement = create_replacement_dict(st.session_state['concept_list'], replacer)

    # Configuration
    # Per-concept overrides: pin each concept's keyword to the concept itself.
    concept_specified_config = {
        x: {'keyword_finder': {'manual_keywords': [x]}} for x in st.session_state['concept_list']
    }
    concept_configuration = {
        'keyword_finder': {
            'require': st.session_state['keyword_finder_requirement'],
            'keyword_number': 1,
        },
        'source_finder': {
            'require': st.session_state['source_finder_requirement'],
            'scrap_number': 10,
            'method': scraper_method,
        },
        'scraper': {
            'require': True,
            'method': scraper_method,
        },
        'prompt_maker': {
            'method': prompt_method,
            # generation_function is None when no model is configured yet.
            'generation_function': model.invoke if model else None,
            'max_benchmark_length': max_benchmark_length,
        },
    }
    domain_configuration = {
        'categories': st.session_state['concept_list'],
        'branching': branching_enabled,
        'branching_config': {
            'generation_function': model.invoke if model else None,
            'keyword_reference': st.session_state['concept_list'],
            'replacement_descriptor_require': False,
            'replacement_description': replacement,
            'branching_pairs': 'not all',
            'direction': 'not both',
        },
        'shared_config': concept_configuration,
        'category_specified_config': concept_specified_config
    }

    # Renew Source Finder Button
    if st.button('Renew Source info'):
        renew_source_finder(st.session_state['domain'], st.session_state['concept_list'])
|
426 |
+
|
427 |
+
# Save the original stdout to print to the terminal if needed later
original_stdout = sys.stdout


# Define StreamToText to capture and display logs in real-time within Streamlit only
class StreamToText:
    # File-like replacement for sys.stdout: every non-empty write is appended
    # to st.session_state.log_messages and re-rendered into the module-level
    # `log_placeholder` (defined further down — assumes it exists before the
    # first write; TODO confirm ordering on first script run).
    def __init__(self):
        # NOTE(review): this buffer is never written to below — appears unused.
        self.output = StringIO()

    def write(self, message):
        if message.strip():  # Avoid adding empty messages
            # Only append to Streamlit display, not the terminal
            st.session_state.log_messages.append(message.strip())
            log_placeholder.text("\n".join(st.session_state.log_messages))  # Flush updated logs

    def flush(self):
        pass  # Required for compatibility with sys.stdout
|
444 |
+
|
445 |
+
|
446 |
+
# Initialize session state for log messages
if 'log_messages' not in st.session_state:
    st.session_state.log_messages = []

# Replace sys.stdout with our custom StreamToText instance
# (everything print()ed after this point goes to the Streamlit log view
# instead of the terminal; original_stdout keeps the real stream).
stream_to_text = StreamToText()
sys.stdout = stream_to_text

# Placeholder for displaying logs within a collapsible expander
with st.expander("Show Logs", expanded=False):
    log_placeholder = st.empty()  # Placeholder for dynamic log display
|
457 |
+
|
458 |
+
# Define the Create Benchmark button
# NOTE(review): relies on check_source_finder / domain_configuration being set
# by the "Additional Options" section above — confirm this block sits inside
# the same domain-confirmed branch in the original file.
if st.button("Create a Benchmark"):
    st.session_state.log_messages = []  # Clear previous logs
    with st.spinner("Creating benchmark..."):
        if st.session_state['check_source_finder']:
            # Check for relevant materials (creates any missing/invalid
            # source-finder files before the pipeline runs).
            check_and_create_source_files(st.session_state['domain'], st.session_state['concept_list'])

        try:
            # Display progress bar and log messages
            progress_bar = st.progress(0)
            for i in tqdm(range(1, 101)):
                progress_bar.progress(i)
                time.sleep(0.05)  # Short delay to simulate processing time

            # Run the benchmark creation function
            benchmark = Pipeline.domain_benchmark_building(st.session_state['domain'], domain_configuration)
            st.success("Benchmark creation completed!")
            st.dataframe(benchmark.data)

        except Exception as e:
            # Surface pipeline failures in the UI instead of crashing the app.
            st.error(f"An error occurred during benchmark creation: {e}")
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ollama==0.3.3
|
2 |
+
openai==1.54.3
|
3 |
+
pandas==1.5.3
|
4 |
+
streamlit==1.40.1
|
5 |
+
tqdm==4.66.4
|
6 |
+
SAGEDbias==0.0.2
|