XinGuan2000 commited on
Commit
29a316b
·
1 Parent(s): 40ea049
Files changed (2) hide show
  1. app.py +479 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from saged import Pipeline
2
+ from tqdm import tqdm
3
+ from pathlib import Path
4
+ from saged import SAGEDData as dt
5
+ import streamlit as st
6
+ import json
7
+ import http.client
8
+ from openai import AzureOpenAI
9
+ import ollama
10
+ import time # Use time.sleep to simulate processing steps
11
+ import logging
12
+ from io import StringIO
13
+ import sys
14
+
15
# Custom logging handler that captures log messages in memory
class StreamlitLogHandler(logging.Handler):
    """Logging handler that buffers formatted records in an in-memory stream.

    The captured text can be fetched with ``get_logs`` (e.g. to render inside
    a Streamlit widget) and discarded with ``clear_logs``.
    """

    def __init__(self):
        super().__init__()
        # Accumulates one formatted message per line.
        self.log_capture_string = StringIO()

    def emit(self, record):
        """Append one formatted, newline-terminated record to the buffer."""
        self.log_capture_string.write(f"{self.format(record)}\n")

    def get_logs(self):
        """Return everything captured so far as a single string."""
        return self.log_capture_string.getvalue()

    def clear_logs(self):
        """Empty the buffer and rewind it for subsequent writes."""
        self.log_capture_string.seek(0)
        self.log_capture_string.truncate(0)
34
+
35
+
36
+
37
# Define ContentFormatter class
class ContentFormatter:
    """Builds JSON request bodies for chat-completion style endpoints."""

    @staticmethod
    def chat_completions(text, settings_params):
        """Return a JSON string holding a system+user message pair plus settings.

        `settings_params` entries are merged into the top-level payload
        alongside the "messages" key.
        """
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text},
        ]
        return json.dumps({"messages": conversation, **settings_params})
47
+
48
+
49
# Define OllamaModel (For local Ollama interaction)
class OllamaModel:
    """Thin wrapper around a locally hosted Ollama model.

    On construction it registers a derived model (named `model_name`) with the
    local Ollama daemon, built from `base_model` with the given system prompt;
    extra keyword arguments become Modelfile PARAMETER lines.
    """

    def __init__(self, base_model='llama3', system_prompt='You are a helpful assistant', model_name='llama3o',
                 **kwargs):
        self.base_model = base_model
        self.model_name = model_name
        # Register the derived model with the daemon up front.
        self.model_create(model_name, system_prompt, base_model, **kwargs)

    def model_create(self, model_name, system_prompt, base_model, **kwargs):
        """Create `model_name` from `base_model` via a generated Modelfile."""
        lines = [f'FROM {base_model}\n', f'SYSTEM {system_prompt}\n']
        lines.extend(f'PARAMETER {key.lower()} {value}\n' for key, value in kwargs.items())
        ollama.create(model=model_name, modelfile=''.join(lines))

    def invoke(self, prompt):
        """Run one generation against the registered model; return its text."""
        return ollama.generate(model=self.model_name, prompt=prompt)['response']
67
+
68
+
69
# Define GPTAgent (For OpenAI GPT models)
class GPTAgent:
    """Chat wrapper for an Azure-hosted OpenAI deployment using the SDK client."""

    def __init__(self, model_name, azure_key, azure_version, azure_endpoint, deployment_name):
        # `model_name` is kept for interface parity with AzureAgent; the Azure
        # SDK routes requests by deployment name, not model name.
        self.model_name = model_name
        self.client = AzureOpenAI(
            api_key=azure_key,
            api_version=azure_version,
            azure_endpoint=azure_endpoint
        )
        self.deployment_name = deployment_name

    def invoke(self, prompt, settings_params=None):
        """Send `prompt` as a user message and return the assistant's reply text.

        `settings_params` (e.g. temperature, max_tokens) are forwarded verbatim
        as chat-completion keyword arguments.
        """
        if not settings_params:
            settings_params = {}
        # Build the message list directly. The previous implementation
        # JSON-encoded messages + settings via ContentFormatter and then
        # immediately json.loads()-ed them back just to extract 'messages' —
        # a redundant round-trip that also sent the settings twice.
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        response = self.client.chat.completions.create(
            model=self.deployment_name,
            messages=messages,
            **settings_params
        )
        return response.choices[0].message.content
89
+
90
+
91
# Define AzureAgent (For Azure OpenAI models)
class AzureAgent:
    """Chat wrapper that talks to an Azure-hosted endpoint over raw HTTPS.

    NOTE(review): `azure_uri` is passed straight to http.client.HTTPSConnection,
    which expects a bare host name — a full "https://..." URL here would fail;
    confirm what callers supply.
    """

    def __init__(self, model_name, azure_uri, azure_api_key):
        self.model_name = model_name
        self.azure_uri = azure_uri
        self.headers = {
            'Authorization': f"Bearer {azure_api_key}",
            'Content-Type': 'application/json'
        }
        self.chat_formatter = ContentFormatter

    def invoke(self, prompt, settings_params=None):
        """POST `prompt` to /v1/chat/completions and return the reply text."""
        if not settings_params:
            settings_params = {}
        body = self.chat_formatter.chat_completions(prompt, {**settings_params})
        conn = http.client.HTTPSConnection(self.azure_uri)
        try:
            conn.request("POST", '/v1/chat/completions', body=body, headers=self.headers)
            data = conn.getresponse().read()
        finally:
            # Previously the connection leaked if request/getresponse raised.
            conn.close()
        parsed_data = json.loads(data.decode("utf-8"))
        return parsed_data["choices"][0]["message"]["content"]
114
+
115
+
116
+
117
# Renew Source Finder Button
def renew_source_finder(domain, concept_list):
    """Delete cached source-finder JSON files for every concept in `domain`.

    Also drops the 'generated_synthetic_files' marker from the session so the
    synthetic-file prompts will be offered again on the next rerun.
    """
    if 'generated_synthetic_files' in st.session_state:
        del st.session_state['generated_synthetic_files']

    # Guard clause: both inputs must be present before touching the disk.
    if not (domain and concept_list):
        st.error("Please fill in all the required fields before proceeding.")
        return

    with st.spinner("Renewing source info files..."):
        base_path = Path('data/customized/source_finder/')
        for concept in concept_list:
            target = base_path / f'{domain}_{concept}_source_finder.json'
            if not target.exists():
                continue
            try:
                target.unlink()  # Delete the file
                st.info(f"Deleted source info file: {target}")
            except Exception as e:
                st.error(f"An error occurred while deleting the file {target}: {e}")
    st.success("Source info files renewal completed!")
135
+
136
+
137
def create_source_finder(domain, concept):
    """Build source-finder data pointing at the local text file for `concept`.

    Warns (but still proceeds) when the expected local file is missing, since
    the file may be uploaded or generated later.
    """
    local_path = f"data/customized/local_files/{domain}/{concept}.txt"
    if not Path(local_path).exists():
        st.warning(f"Local file does not exist: {local_path}")

    instance = dt.create_data(domain, concept, 'source_finder')
    entry = instance.data[0]
    entry['keywords'] = {concept: dt.default_keyword_metadata.copy()}

    shared_source = dt.default_source_item.copy()
    shared_source['source_type'] = "local_paths"
    shared_source['source_specification'] = [local_path]
    entry['category_shared_source'] = [shared_source]
    return instance.data.copy()
148
+
149
+
150
def _write_source_finder_file(domain, concept, file_path, reason):
    """Write a fresh source-finder payload to `file_path` and report it."""
    data = create_source_finder(domain, concept)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    st.info(f"{reason} source finder file: {file_path}")


def check_and_create_source_files(domain, concept_list):
    """
    Checks if the required source finder files exist for each concept in the domain.
    If a file does not exist or is invalid, it creates an empty JSON file for that concept.
    """
    base_path = Path('data/customized/source_finder/')
    base_path.mkdir(parents=True, exist_ok=True)
    for concept in concept_list:
        file_path = base_path / f'{domain}_{concept}_source_finder.json'
        if not file_path.exists():
            # Create a new source finder file using create_source_finder
            _write_source_finder_file(domain, concept, file_path, "Created missing")
        else:
            # Attempt to load the file to verify its validity;
            # load_file returning None signals an unreadable/invalid file.
            instance = dt.load_file(domain, concept, 'source_finder', file_path)
            if instance is None:
                _write_source_finder_file(domain, concept, file_path, "Recreated invalid")
174
+
175
+
176
def clean_spaces(data):
    """
    Removes trailing or leading spaces from a string or from each element in a list.

    Non-string list elements pass through untouched; any other input type
    raises TypeError.
    """
    if isinstance(data, str):
        return data.strip()
    if isinstance(data, list):
        cleaned = []
        for element in data:
            cleaned.append(element.strip() if isinstance(element, str) else element)
        return cleaned
    raise TypeError("Input should be either a string or a list of strings")
186
+
187
+
188
def create_replacement_dict(concept_list, replacer):
    """Map each concept to a per-company substitution spec.

    Result shape: {concept: {company: {concept: company}}} for every company
    in `replacer`.
    """
    return {
        concept: {company: {concept: company} for company in replacer}
        for concept in concept_list
    }
195
+
196
+
197
# Title of the app
st.title("SAGED-bias Benchmark-Building Demo")

# Initialize session state variables
# Each key is pre-seeded with None so later code can read
# st.session_state[...] without risking a KeyError on the first run.
if 'domain' not in st.session_state:
    st.session_state['domain'] = None
if 'concept_list' not in st.session_state:
    st.session_state['concept_list'] = None
if 'gpt_model' not in st.session_state:
    st.session_state['gpt_model'] = None
if 'azure_model' not in st.session_state:
    st.session_state['azure_model'] = None
if 'ollama_model' not in st.session_state:
    st.session_state['ollama_model'] = None
211
+
212
# Sidebar: Model Selection
with st.sidebar:
    st.header("Model Configuration")

    # Selection of which model to use
    model_selection = st.radio("Select Model Type", ['GPT-Azure', 'Azure', 'Ollama'])

    # Collapsible Additional Configuration Section
    with st.expander("Model Configuration"):
        if model_selection == 'Ollama':
            # Ollama Configuration
            ollama_deployment_name = st.text_input("Enter Ollama Model Deployment Name", placeholder="e.g., llama3")
            ollama_system_prompt = st.text_input("Enter System Prompt for Ollama",
                                                 placeholder="e.g., You are a helpful assistant.")

            if ollama_deployment_name and ollama_system_prompt:
                confirm_ollama = st.button("Confirm Ollama Configuration")
                if confirm_ollama:
                    # NOTE(review): constructing OllamaModel registers the model
                    # with the local Ollama daemon on every confirm click.
                    st.session_state['ollama_model'] = OllamaModel(
                        model_name=ollama_deployment_name,
                        system_prompt=ollama_system_prompt
                    )
                    st.success("Ollama model configured successfully.")
            else:
                st.warning("Please provide both Ollama deployment name and system prompt.")

        elif model_selection == 'GPT-Azure' or model_selection == 'Azure':
            # GPT / Azure Configuration
            gpt_azure_endpoint = st.text_input("Enter Azure Endpoint URL",
                                               placeholder="e.g., https://your-resource-name.openai.azure.com/")
            gpt_azure_api_key = st.text_input("Enter Azure API Key", type="password")
            gpt_azure_model_name = st.text_input("Enter Azure Model Name", placeholder="e.g., GPT-3.5-turbo")
            gpt_azure_deployment_name = st.text_input("Enter Azure Deployment Name",
                                                      placeholder="e.g., gpt-3-5-deployment")

            if gpt_azure_endpoint and gpt_azure_api_key and gpt_azure_model_name and gpt_azure_deployment_name:
                confirm_gpt_azure = st.button("Confirm GPT/Azure Configuration")
                if confirm_gpt_azure:
                    if model_selection == 'GPT-Azure':
                        # SDK-based client for an Azure OpenAI deployment.
                        st.session_state['gpt_model'] = GPTAgent(
                            model_name=gpt_azure_model_name,
                            azure_key=gpt_azure_api_key,
                            azure_version='2023-05-15',  # Update if necessary
                            azure_endpoint=gpt_azure_endpoint,
                            deployment_name=gpt_azure_deployment_name
                        )
                        st.success("GPT model configured successfully.")
                    elif model_selection == 'Azure':
                        # Raw-HTTPS client. NOTE(review): AzureAgent feeds this
                        # endpoint to http.client.HTTPSConnection, which expects
                        # a bare host name, not a full https:// URL — confirm.
                        st.session_state['azure_model'] = AzureAgent(
                            model_name=gpt_azure_model_name,
                            azure_uri=gpt_azure_endpoint,
                            azure_api_key=gpt_azure_api_key
                        )
                        st.success("Azure model configured successfully.")
            else:
                st.warning("Please provide all fields for GPT/Azure configuration.")
268
+
269
# Main interaction based on configured model
# Pick whichever configured model to use, preferring Ollama, then GPT,
# then Azure; `model` stays None when nothing has been configured yet.
if st.session_state.get('ollama_model'):
    model = st.session_state['ollama_model']
elif st.session_state.get('gpt_model'):
    model = st.session_state['gpt_model']
elif st.session_state.get('azure_model'):
    model = st.session_state['azure_model']
else:
    model = None
278
+
279
# User input: Domain and Concepts
with st.form(key='domain_concept_form'):
    domain = clean_spaces(
        st.text_input("Enter the domain: (e.g., Stocks, Education)", placeholder="Enter domain here..."))

    # User input: Concepts
    # Comma-separated free text; each entry is whitespace-trimmed.
    concept_text = st.text_area("Enter the concepts (separated by commas):",
                                placeholder="e.g., excel-stock, ok-stock, bad-stock")
    concept_list = clean_spaces(concept_text.split(','))

    submit_button = st.form_submit_button(label='Confirm Domain and Concepts')

    if submit_button:
        if not domain:
            st.warning("Please enter a domain.")
        elif not concept_list or concept_text.strip() == "":
            st.warning("Please enter at least one concept.")
        else:
            # Persist the confirmed values so later reruns can use them.
            st.session_state['domain'] = domain
            st.session_state['concept_list'] = concept_list
            st.success("Domain and concepts confirmed.")
300
+
301
# Display further options only after domain and concepts are confirmed
if st.session_state['domain'] and st.session_state['concept_list']:
    with st.expander("Additional Options"):
        # User input: Method
        scraper_method = st.radio("Select the scraper method:", (('wiki', 'local_files', 'synthetic_files')))

        # Initiate the source_finder_requirement and keyword_finder_requirement if 'wiki' is selected
        if scraper_method == 'wiki':
            st.session_state['keyword_finder_requirement'] = True
            st.session_state['source_finder_requirement'] = True
            st.session_state['check_source_finder'] = False

        # File upload for each concept if 'local_files' is selected
        if scraper_method == 'local_files':
            uploaded_files = {}
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            for concept in st.session_state['concept_list']:
                uploaded_file = st.file_uploader(f"Upload file for concept '{concept}':", type=['txt'],
                                                 key=f"file_{concept}")
                if uploaded_file:
                    uploaded_files[concept] = uploaded_file
                    # Save uploaded file
                    save_path = Path(f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(save_path, 'wb') as f:
                        f.write(uploaded_file.getbuffer())
                    st.success(f"File for concept '{concept}' saved successfully.")

        # Generate synthetic files if 'synthetic_files' is selected
        if scraper_method == 'synthetic_files':
            # The generated files are consumed downstream exactly like
            # user-uploaded local files, hence the method swap here.
            scraper_method = 'local_files'
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            if 'generated_synthetic_files' not in st.session_state:
                st.session_state['generated_synthetic_files'] = set()

            # One editable prompt per concept that has not been generated yet.
            prompt_inputs = {}
            for concept in st.session_state['concept_list']:
                if concept not in st.session_state['generated_synthetic_files']:
                    prompt_inputs[concept] = st.text_input(
                        f"Enter the prompt for concept '{concept}':",
                        value=f"Write a long article introducing the {concept} in the {st.session_state['domain']}. Use the {concept} as much as possible.",
                        key=f"prompt_{concept}"
                    )

            if st.button("Generate Synthetic Files for All Concepts"):
                if model:
                    for concept, prompt in prompt_inputs.items():
                        if prompt:
                            with st.spinner(f"Generating content for concept '{concept}'..."):
                                synthetic_content = model.invoke(prompt)
                                save_path = Path(
                                    f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                                save_path.parent.mkdir(parents=True, exist_ok=True)
                                with open(save_path, 'w', encoding='utf-8') as f:
                                    f.write(synthetic_content)
                                # Remember so this concept's prompt box is hidden
                                # on subsequent reruns.
                                st.session_state['generated_synthetic_files'].add(concept)
                                st.success(f"Synthetic file for concept '{concept}' created successfully.")
                else:
                    st.warning("Please configure a model to generate synthetic files.")

        # User input: Prompt Method
        prompt_method = st.radio("Select the prompt method:", ('split_sentences', 'questions'), index = 0)

        # User input: Max Benchmark Length
        max_benchmark_length = st.slider("Select the maximum prompts per concepts:", 1, 199, 10)

        # User input: Branching
        branching = st.radio("Enable branching:", ('Yes', 'No'), index=1)
        branching_enabled = True if branching == 'Yes' else False

        # User input: Replacer (only if branching is enabled)
        replacer = []
        replacement = {}
        if branching_enabled:
            replacer_text = st.text_area("Enter the replacer list (list of strings, separated by commas):",
                                         placeholder="e.g., Company A, Company B")
            replacer = clean_spaces(replacer_text.split(','))
            replacement = create_replacement_dict(st.session_state['concept_list'], replacer)

        # Configuration
        # Per-concept override: pin each concept's keywords to the concept
        # name itself rather than discovering keywords automatically.
        concept_specified_config = {
            x: {'keyword_finder': {'manual_keywords': [x]}} for x in st.session_state['concept_list']
        }
        concept_configuration = {
            'keyword_finder': {
                'require': st.session_state['keyword_finder_requirement'],
                'keyword_number': 1,
            },
            'source_finder': {
                'require': st.session_state['source_finder_requirement'],
                'scrap_number': 10,
                'method': scraper_method,
            },
            'scraper': {
                'require': True,
                'method': scraper_method,
            },
            'prompt_maker': {
                'method': prompt_method,
                # Falls back to None when no model is configured yet.
                'generation_function': model.invoke if model else None,
                'max_benchmark_length': max_benchmark_length,
            },
        }
        domain_configuration = {
            'categories': st.session_state['concept_list'],
            'branching': branching_enabled,
            'branching_config': {
                'generation_function': model.invoke if model else None,
                'keyword_reference': st.session_state['concept_list'],
                'replacement_descriptor_require': False,
                'replacement_description': replacement,
                'branching_pairs': 'not all',
                'direction': 'not both',
            },
            'shared_config': concept_configuration,
            'category_specified_config': concept_specified_config
        }

        # Renew Source Finder Button
        if st.button('Renew Source info'):
            renew_source_finder(st.session_state['domain'], st.session_state['concept_list'])
426
+
427
# Save the original stdout to print to the terminal if needed later
original_stdout = sys.stdout


# Define StreamToText to capture and display logs in real-time within Streamlit only
class StreamToText:
    # File-like stdout replacement: every non-empty write() is appended to
    # st.session_state.log_messages and the whole log is re-rendered into the
    # module-level `log_placeholder` widget.
    # NOTE(review): `log_placeholder` is only created further down the script;
    # a write() happening before that point would raise NameError — confirm.
    def __init__(self):
        self.output = StringIO()

    def write(self, message):
        if message.strip():  # Avoid adding empty messages
            # Only append to Streamlit display, not the terminal
            st.session_state.log_messages.append(message.strip())
            log_placeholder.text("\n".join(st.session_state.log_messages))  # Flush updated logs

    def flush(self):
        pass  # Required for compatibility with sys.stdout
443
+ pass # Required for compatibility with sys.stdout
444
+
445
+
446
+ # Initialize session state for log messages
447
+ if 'log_messages' not in st.session_state:
448
+ st.session_state.log_messages = []
449
+
450
+ # Replace sys.stdout with our custom StreamToText instance
451
+ stream_to_text = StreamToText()
452
+ sys.stdout = stream_to_text
453
+
454
+ # Placeholder for displaying logs within a collapsible expander
455
+ with st.expander("Show Logs", expanded=False):
456
+ log_placeholder = st.empty() # Placeholder for dynamic log display
457
+
458
# Define the Create Benchmark button
if st.button("Create a Benchmark"):
    st.session_state.log_messages = []  # Clear previous logs
    with st.spinner("Creating benchmark..."):
        # NOTE(review): 'check_source_finder' is only written once a scraper
        # method has been chosen under "Additional Options"; confirm it cannot
        # be missing here (KeyError) on a fresh session.
        if st.session_state['check_source_finder']:
            # Check for relevant materials
            check_and_create_source_files(st.session_state['domain'], st.session_state['concept_list'])

        try:
            # Display progress bar and log messages
            # The bar is purely cosmetic: it ticks 1..100 with a short sleep
            # (~5s total) before the real, blocking pipeline call below.
            progress_bar = st.progress(0)
            for i in tqdm(range(1, 101)):
                progress_bar.progress(i)
                time.sleep(0.05)  # Short delay to simulate processing time

            # Run the benchmark creation function
            benchmark = Pipeline.domain_benchmark_building(st.session_state['domain'], domain_configuration)
            st.success("Benchmark creation completed!")
            st.dataframe(benchmark.data)

        except Exception as e:
            st.error(f"An error occurred during benchmark creation: {e}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ollama==0.3.3
2
+ openai==1.54.3
3
+ pandas==1.5.3
4
+ streamlit==1.40.1
5
+ tqdm==4.66.4
6
+ SAGEDbias==0.0.2