Spaces: phyloforfun (Running)

phyloforfun committed • Commit e729f97 • 1 Parent(s): 12ba60e

refactor

Browse files
- app.py +371 -898
- vouchervision/utils.py +99 -0

app.py CHANGED
@@ -1,24 +1,22 @@
import streamlit as st
-import yaml, os, json, random, time,
-import matplotlib.pyplot as plt
import plotly.graph_objs as go
-import numpy as np
from itertools import chain
from PIL import Image
from io import BytesIO
-import base64
-import pandas as pd
-from typing import Union
-from google.oauth2 import service_account
from streamlit_extras.let_it_rain import rain
-
-from googleapiclient.discovery import build
-from googleapiclient.http import MediaFileUpload
from vouchervision.LeafMachine2_Config_Builder import write_config_file
-from vouchervision.VoucherVision_Config_Builder import build_VV_config
-from vouchervision.vouchervision_main import voucher_vision
-from vouchervision.general_utils import

PROMPTS_THAT_NEED_DOMAIN_KNOWLEDGE = ["Version 1","Version 1 PaLM 2"]
# LLM_VERSIONS = ["GPT 4", "GPT 3.5", "Azure GPT 4", "Azure GPT 3.5", "PaLM 2"]
COLORS_EXPENSE_REPORT = {
@@ -28,7 +26,13 @@ COLORS_EXPENSE_REPORT = {
}
MAX_GALLERY_IMAGES = 50
GALLERY_IMAGE_SIZE = 128

class ProgressReport:
    def __init__(self, overall_bar, batch_bar, text_overall, text_batch):
        self.overall_bar = overall_bar
@@ -67,42 +71,16 @@ class ProgressReport:
        self.overall_bar.progress(0)
        self.text_overall.text(step_name)

-
    def get_n_images(self):
        return self.n_images
    def get_n_overall(self):
        return self.total_overall_steps
-
-def does_private_file_exist():
-    dir_home = os.path.dirname(os.path.dirname(__file__))
-    path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
-    return os.path.exists(path_cfg_private)
-
-def setup_streamlit_config(dir_home):
-    # Define the directory path and filename
-    dir_path = os.path.join(dir_home, ".streamlit")
-    file_path = os.path.join(dir_path, "config.toml")
-
-    # Check if directory exists, if not create it
-    if not os.path.exists(dir_path):
-        os.makedirs(dir_path)

-    # Create or modify the file with the provided content
-    config_content = f"""
-    [theme]
-    base = "dark"
-    primaryColor = "#00ff00"
-
-    [server]
-    enableStaticServing = false
-    runOnSave = true
-    port = 8524
-    maxUploadSize = 5000
-    """

-    with open(file_path, "w") as f:
-        f.write(config_content.strip())

def display_scrollable_results(JSON_results, test_results, OPT2, OPT3):
    """
    Display the results from JSON_results in a scrollable container.
@@ -145,6 +123,8 @@ def display_scrollable_results(JSON_results, test_results, OPT2, OPT3):
    st.markdown(css, unsafe_allow_html=True)
    st.markdown(results_html, unsafe_allow_html=True)

def display_test_results(test_results, JSON_results, llm_version):
    if llm_version == 'gpt':
        OPT1, OPT2, OPT3 = TestOptionsGPT.get_options()
@@ -182,7 +162,6 @@ def display_test_results(test_results, JSON_results, llm_version):
    # Close the custom container
    st.write('</div>', unsafe_allow_html=True)

-
    for idx, (test_name, result) in enumerate(sorted(test_results.items())):
        _, ind_opt1, ind_opt2, ind_opt3 = test_name.split('__')
        opt2_readable = "Use LeafMachine2" if OPT2[int(ind_opt2.split('-')[1])] else "Don't use LeafMachine2"
@@ -209,9 +188,13 @@ def display_test_results(test_results, JSON_results, llm_version):
    # proportional_rain("🥇", success_count, "💔", failure_count, font_size=72, falling_speed=5, animation_length="infinite")
    rain_emojis(test_results)

def add_emoji_delay():
    time.sleep(0.3)

def rain_emojis(test_results):
    # test_results = {
    #     'test1': True, # Test passed
@@ -251,6 +234,8 @@ def rain_emojis(test_results):
        )
        add_emoji_delay()

def get_prompt_versions(LLM_version):
    yaml_files = [f for f in os.listdir(os.path.join(st.session_state.dir_home, 'custom_prompts')) if f.endswith('.yaml')]
@@ -264,42 +249,7 @@ def get_prompt_versions(LLM_version):
    # Handle other cases or raise an error
    return (yaml_files, None)

-def get_private_file():
-    dir_home = os.path.dirname(os.path.dirname(__file__))
-    path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
-    return get_cfg_from_full_path(path_cfg_private)
-
-def create_space_saver():
-    st.subheader("Space Saving Options")
-    col_ss_1, col_ss_2 = st.columns([2,2])
-    with col_ss_1:
-        st.write("Several folders are created and populated with data during the VoucherVision transcription process.")
-        st.write("Below are several options that will allow you to automatically delete temporary files that you may not need for everyday operations.")
-        st.write("VoucherVision creates the following folders. Folders marked with a :star: are required if you want to use VoucherVisionEditor for quality control.")
-        st.write("`../[Run Name]/Archival_Components`")
-        st.write("`../[Run Name]/Config_File`")
-        st.write("`../[Run Name]/Cropped_Images` :star:")
-        st.write("`../[Run Name]/Logs`")
-        st.write("`../[Run Name]/Original_Images` :star:")
-        st.write("`../[Run Name]/Transcription` :star:")
-    with col_ss_2:
-        st.session_state.config['leafmachine']['project']['delete_temps_keep_VVE'] = st.checkbox("Delete Temporary Files (KEEP files required for VoucherVisionEditor)", st.session_state.config['leafmachine']['project'].get('delete_temps_keep_VVE', False))
-        st.session_state.config['leafmachine']['project']['delete_all_temps'] = st.checkbox("Keep only the final transcription file", st.session_state.config['leafmachine']['project'].get('delete_all_temps', False), help="*WARNING:* This limits your ability to do quality assurance. This will delete all folders created by VoucherVision, leaving only the `transcription.xlsx` file.")

-def save_uploaded_file(directory, img_file, image=None):
-    if not os.path.exists(directory):
-        os.makedirs(directory)
-    # Assuming the uploaded file is an image
-    if image is None:
-        with Image.open(img_file) as image:
-            full_path = os.path.join(directory, img_file.name)
-            image.save(full_path, "JPEG")
-        # Return the full path of the saved image
-        return full_path
-    else:
-        full_path = os.path.join(directory, img_file.name)
-        image.save(full_path, "JPEG")
-        return full_path

def delete_directory(dir_path):
    try:
@@ -311,305 +261,188 @@ def delete_directory(dir_path):
        st.error(f"Error: {dir_path} : {e.strerror}")


-# def create_private_file():
-#     st.session_state.proceed_to_main = False
-#     st.title("VoucherVision")
-#     col_private, _ = st.columns([12, 2])
-
-#     openai_api_key = None
-#     azure_openai_api_version = None
-#     azure_openai_api_key = None
-#     azure_openai_api_base = None
-#     azure_openai_organization = None
-#     azure_openai_api_type = None
-#     google_vision = None
-#     google_palm = None
-
-#     # Fetch the environment variables or set to empty if not found
-#     env_variables = {
-#         'OPENAI_API_KEY': os.getenv('OPENAI_API_KEY'),
-#         'AZURE_API_VERSION': os.getenv('AZURE_API_VERSION'),
-#         'AZURE_API_KEY': os.getenv('AZURE_API_KEY'),
-#         'AZURE_API_BASE': os.getenv('AZURE_API_BASE'),
-#         'AZURE_ORGANIZATION': os.getenv('AZURE_ORGANIZATION'),
-#         'AZURE_API_TYPE': os.getenv('AZURE_API_TYPE'),
-#         'AZURE_DEPLOYMENT_NAME': os.getenv('AZURE_DEPLOYMENT_NAME'),
-#         'GOOGLE_APPLICATION_CREDENTIALS': os.getenv('GOOGLE_APPLICATION_CREDENTIALS'),
-#         'PALM_API_KEY': os.getenv('PALM_API_KEY')
-#     }
-
-#     # Check if all environment variables are set
-#     all_env_set = all(value is not None for value in env_variables.values())
-
-#     with col_private:
-#         # Your existing UI code for showing the forms goes here
-#         st.header("Set API keys")
-#         st.info("***Note:*** There is a known bug with tabs in Streamlit. If you update an input field it may take you back to the 'Project Settings' tab. Changes that you made are saved, it's just an annoying glitch. We are aware of this issue and will fix it as soon as we can.")
-#         st.warning("To commit changes to API keys you must press the 'Set API Keys' button at the bottom of the page.")
-#         st.write("Before using VoucherVision you must set your API keys. All keys are stored locally on your computer and are never made public.")
-#         st.write("API keys are stored in `../VoucherVision/PRIVATE_DATA.yaml`.")
-#         st.write("Deleting this file will allow you to reset API keys. Alternatively, you can edit the keys in the user interface.")
-#         st.write("Leave keys blank if you do not intend to use that service.")
-

-#
-
-#         - Select "Manage keys."
-#         - In the pop-up window, click on the "ADD KEY" button and select "JSON."
-#         - The JSON key file will automatically be downloaded to your computer.
-#         - **Store Safely**: This file contains sensitive data that can be used to authenticate and bill your Google Cloud account. Never commit it to public repositories or expose it in any way. Always keep it safe and secure.
-#         """)
-#         with st.container():
-#             c_in_ocr, c_button_ocr = st.columns([10,2])
-#             with c_in_ocr:
-#                 google_vision = st.text_input(label = 'Full path to Google Cloud JSON API key file', value = '',
-#                                               placeholder = 'e.g. copy contents of file application_default_credentials.json',
-#                                               help ="This API Key is in the form of a JSON file. Please save the JSON file in a safe directory. DO NOT store the JSON key inside of the VoucherVision directory.",
-#                                               type='password',key='924857298734590283750932809238')
-#                 st.secrets["db_username"]
-#             with c_button_ocr:
-#                 st.empty()
-
-#         with st.container():
-#             with c_button_ocr:
-#                 st.write("##")
-#                 st.button("Test OCR", on_click=test_API, args=['google_vision',c_in_ocr,openai_api_key,azure_openai_api_version,azure_openai_api_key,
-#                                                                azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
-
-
-#         if os.getenv('OPENAI_API_KEY') is None:
-#             st.write("---")
-#             st.subheader("OpenAI")
-#             st.markdown("API key for first-party OpenAI API. Create an account with OpenAI [here](https://platform.openai.com/signup), then create an API key [here](https://platform.openai.com/account/api-keys).")
-#             with st.container():
-#                 c_in_openai, c_button_openai = st.columns([10,2])
-#                 with c_in_openai:
-#                     openai_api_key = st.text_input("openai_api_key", os.environ.get('OPENAI_API_KEY', ''),
-#                                                    help='The actual API key. Likely to be a string of 2 character, a dash, and then a 48-character string: sk-XXXXXXXX...',
-#                                                    placeholder = 'e.g. sk-XXXXXXXX...',
-#                                                    type='password')
-#                 with c_button_openai:
-#                     st.empty()
-#             with st.container():
-#                 with c_button_openai:
-#                     st.write("##")
-#                     st.button("Test OpenAI", on_click=test_API, args=['openai',c_in_openai,openai_api_key,azure_openai_api_version,azure_openai_api_key,
-#                                                                       azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
-


-#
-
-#         st.subheader("OpenAI - Azure")
-#         st.markdown("This version OpenAI relies on Azure servers directly as is intended for private enterprise instances of OpenAI's services, such as [UM-GPT](https://its.umich.edu/computing/ai). Administrators will provide you with the following information.")
-#         azure_openai_api_version = st.text_input("azure_openai_api_version", os.environ.get('AZURE_API_VERSION', ''),
-#                                                  help='API Version e.g. "2023-05-15"',
-#                                                  placeholder = 'e.g. 2023-05-15',
-#                                                  type='password')
-#         azure_openai_api_key = st.text_input("azure_openai_api_key", os.environ.get('AZURE_API_KEY', ''),
-#                                              help='The actual API key. Likely to be a 32-character string',
-#                                              placeholder = 'e.g. 12333333333333333333333333333332',
-#                                              type='password')
-#         azure_openai_api_base = st.text_input("azure_openai_api_base", os.environ.get('AZURE_API_BASE', ''),
-#                                               help='The base url for the API e.g. "https://api.umgpt.umich.edu/azure-openai-api"',
-#                                               placeholder = 'e.g. https://api.umgpt.umich.edu/azure-openai-api',
-#                                               type='password')
-#         azure_openai_organization = st.text_input("azure_openai_organization", os.environ.get('AZURE_ORGANIZATION', ''),
-#                                                   help='Your organization code. Likely a short string',
-#                                                   placeholder = 'e.g. 123456',
-#                                                   type='password')
-#         azure_openai_api_type = st.text_input("azure_openai_api_type", os.environ.get('AZURE_API_TYPE', ''),
-#                                               help='The API type. Typically "azure"',
-#                                               placeholder = 'e.g. azure',
-#                                               type='password')
-#         with st.container():
-#             c_in_azure, c_button_azure = st.columns([10,2])
-#             with c_button_azure:
-#                 st.empty()
-#         with st.container():
-#             with c_button_azure:
-#                 st.write("##")
-#                 st.button("Test Azure OpenAI", on_click=test_API, args=['azure_openai',c_in_azure,openai_api_key,azure_openai_api_version,azure_openai_api_key,
-#                                                                         azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
-


-#         if os.getenv('PALM_API_KEY') is None:
-#             st.write("---")
-#             st.subheader("Google PaLM 2")
-#             st.markdown('Follow these [instructions](https://developers.generativeai.google/tutorials/setup) to generate an API key for PaLM 2. You may need to also activate an account with [MakerSuite](https://makersuite.google.com/app/apikey) and enable "early access."')
-#             with st.container():
-#                 c_in_palm, c_button_palm = st.columns([10,2])
-#                 with c_in_palm:
-#                     google_palm = st.text_input("Google PaLM 2 API Key", os.environ.get('PALM_API_KEY', ''),
-#                                                 help='The MakerSuite API key e.g. a 32-character string',
-#                                                 placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
-#                                                 type='password')
-#             with st.container():
-#                 with c_button_palm:
-#                     st.write("##")
-#                     st.button("Test PaLM 2", on_click=test_API, args=['palm',c_in_palm,openai_api_key,azure_openai_api_version,azure_openai_api_key,
-#                                                                       azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
-
-#     st.button("Set API Keys",type='primary', on_click=set_API_keys, args=[openai_api_key,azure_openai_api_version,azure_openai_api_key,
-#                                                                           azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
-
-#     # # UI form for entering environment variables if not all are set
-#     # with st.form("env_variables"):
-#     #     for var, value in env_variables.items():
-#     #         env_variables[var] = st.text_input(f"Enter {var}", value or "")
-#     #     submitted = st.form_submit_button("Submit")
-#     #     if submitted:
-#     #         # Assuming the environment variables should be set for the session
-#     #         for var, value in env_variables.items():
-#     #             os.environ[var] = value
-#     #         st.success("Environment variables updated. Please restart your app.")
-#     if st.button('Proceed to VoucherVision'):
-#         st.session_state.proceed_to_private = False
-#         st.session_state.proceed_to_main = True
-
-# def set_API_keys(openai_api_key, azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type, google_vision, google_palm):
-#     # Set the environment variable if the key is not None or an empty string
-#     if openai_api_key:
-#         os.environ['OPENAI_API_KEY'] = openai_api_key
-#     if azure_openai_api_version:
-#         os.environ['AZURE_API_VERSION'] = azure_openai_api_version
-#     if azure_openai_api_key:
-#         os.environ['AZURE_API_KEY'] = azure_openai_api_key
-#     if azure_openai_api_base:
-#         os.environ['AZURE_API_BASE'] = azure_openai_api_base
-#     if azure_openai_organization:
-#         os.environ['AZURE_ORGANIZATION'] = azure_openai_organization
-#     if azure_openai_api_type:
-#         os.environ['AZURE_API_TYPE'] = azure_openai_api_type
-#     if google_vision:
-#         os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = google_vision
-#     if google_palm:
-#         os.environ['GOOGLE_PALM_API'] = google_palm
-
-#     st.success("API keys set successfully!")


-
-
-#
-
-#
-
-# """
-# # Initialize the container
-# con_image = st.empty()
-# with con_image.container():
-#     # Loop through each image in the input list
-#     for image_path in st.session_state['input_list']:
-#         img = Image.open(image_path)
-#         img.thumbnail((120, 120), Image.Resampling.LANCZOS)

-#
-

def show_available_APIs():
    st.session_state['has_key_openai'] = (os.getenv('OPENAI_API_KEY') is not None) and (os.getenv('OPENAI_API_KEY') != '')
@@ -680,192 +513,143 @@ def display_image_gallery():
    # Apply the CSS
    st.markdown(css, unsafe_allow_html=True)

-def
-
-
-
-
-    cfg_private['openai_azure']['api_version'] = azure_openai_api_version
-    cfg_private['openai_azure']['openai_api_key'] = azure_openai_api_key
-    cfg_private['openai_azure']['openai_api_base'] = azure_openai_api_base
-    cfg_private['openai_azure']['openai_organization'] = azure_openai_organization
-    cfg_private['openai_azure']['openai_api_type'] = azure_openai_api_type

-    cfg_private['google_cloud']['path_json_file'] = google_vision

-    cfg_private['google_palm']['google_palm_api'] = google_palm
-    # Call the function to write the updated configuration to the YAML file
-    write_config_file(cfg_private, st.session_state.dir_home, filename="PRIVATE_DATA.yaml")
-    st.session_state.private_file = does_private_file_exist()

-
-
-
-    st.session_state['instructions'] = st.session_state['prompt_info'].get('instructions', st.session_state['default_instructions'])
-    st.session_state['json_formatting_instructions'] = st.session_state['prompt_info'].get('json_formatting_instructions', st.session_state['default_json_formatting_instructions'] )
-    st.session_state['rules'] = st.session_state['prompt_info'].get('rules', {})
-    st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})

-
-
-
-        'prompt_description': st.session_state['prompt_description'],
-        'LLM': st.session_state['LLM'],
-        'instructions': st.session_state['instructions'],
-        'json_formatting_instructions': st.session_state['json_formatting_instructions'],
-        'rules': st.session_state['rules'],
-        'mapping': st.session_state['mapping'],
-    }

-#

-
-
-
-
-        'prompt_description': st.session_state['prompt_description'],
-        'LLM': st.session_state['LLM'],
-        'instructions': st.session_state['instructions'],
-        'json_formatting_instructions': st.session_state['json_formatting_instructions'],
-        'rules': st.session_state['rules'],
-        'mapping': st.session_state['mapping'],
-    }
-
-    dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
-    filepath = os.path.join(dir_prompt, f"{filename}.yaml")

-
-

-
-

-
-

-#
-
-
-
-
-
-        creds_info, scopes=["https://www.googleapis.com/auth/drive"]
    )
-
-
-    # Get the folder ID from the environment variable
-    folder_id = os.environ.get('GDRIVE')
-    # st.info(f"{folder_id}")
-
-    if folder_id:
-        file_metadata = {
-            'name': filename,
-            'parents': [folder_id]
-        }
-        # st.info(f"{file_metadata}")
-
-        media = MediaFileUpload(filepath, mimetype='application/x-yaml')

-
-
-
-
-        ).execute()

-
-
-
-
-
-

-
-
-
-
-
-        # The hyphen - is literally matched.

-
-
-    else:
-        return False

-
-
-
-
-
-
-
-
-    )

-
-
-
-
-
-
-
-    st.
-    st.session_state['prompt_author_institution'] = st.session_state['default_prompt_author_institution']
-    st.session_state['prompt_description'] = st.session_state['default_prompt_description']
-    st.session_state['instructions'] = st.session_state['default_instructions']
-    st.session_state['json_formatting_instructions'] = st.session_state['default_json_formatting_instructions']
-    st.session_state['rules'] = {}
-    st.session_state['LLM'] = 'gpt'
-
-    st.session_state['assigned_columns'] = []

-
-
-        'prompt_author_institution': st.session_state['prompt_author_institution'],
-        'prompt_description': st.session_state['prompt_description'],
-        'LLM': st.session_state['LLM'],
-        'instructions': st.session_state['instructions'],
-        'json_formatting_instructions': st.session_state['json_formatting_instructions'],
-        'rules': st.session_state['rules'],
-        'mapping': st.session_state['mapping'],
-    }

-def refresh():
-    st.write('')

-def upload_local_prompt_to_server(dir_prompt):
-    uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
-    if uploaded_file is not None:
-        # Check the file extension
-        file_name = uploaded_file.name
-        if file_name.endswith('.yaml'):
-            file_path = os.path.join(dir_prompt, file_name)
-
-            # Save the file
-            with open(file_path, 'wb') as f:
-                f.write(uploaded_file.getbuffer())
-            st.success(f"Saved file {file_name} in {dir_prompt}")
-        else:
-            st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")

-def
-
-
-
-
-
-
-    )
-

def build_LLM_prompt_config():
    st.session_state['assigned_columns'] = []
    st.session_state['default_prompt_author'] = 'unknown'
@@ -895,10 +679,8 @@ The desired null value is also given. Populate the field with the null value of

    # Start building the Streamlit app
    col_prompt_main_left, ___, col_prompt_main_right = st.columns([6,1,3])
-

    with col_prompt_main_left:
-
        st.title("Custom LLM Prompt Builder")
        st.subheader('About')
        st.write("This form allows you to craft a prompt for your specific task.")
@@ -913,7 +695,6 @@ The desired null value is also given. Populate the field with the null value of
        st.write("4. When you go back the main VoucherVision page you will now see your custom prompt available in the 'Prompt Version' dropdown menu.")
        st.write("5. Select your custom prompt. Note, your prompt will only be available for the LLM that you set when filling out the form below.")

-
        dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
        yaml_files = [f for f in os.listdir(dir_prompt) if f.endswith('.yaml')]
        col_upload_yaml, col_upload_yaml_2 = st.columns([4,4])
@@ -938,7 +719,6 @@ The desired null value is also given. Populate the field with the null value of
            # Create the download button
            st.write('##')
            create_download_button_yaml(download_file_path, st.session_state['selected_yaml_file'] )
-

        # Prompt Author Information
        st.header("Prompt Author Information")
@@ -950,18 +730,16 @@ The desired null value is also given. Populate the field with the null value of

        st.write("Please provide a description of your prompt and its intended task. Is it designed for a specific collection? Taxa? Database structure?")
        st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description']))
-

-        st.write('---')
        # Input for new file name
        st.header("Prompt Name")
        st.write('Provide a name for your custom prompt. It can only conatin letters, numbers, and underscores. No spaces, dashes, or special characters.')
        st.session_state['new_prompt_yaml_filename'] = st.text_input("Enter filename to save your prompt as a configuration YAML:", value=None, placeholder='my_prompt_name')

-
        st.write('---')
        st.header("Set LLM Model Type")
-        # Define the options for the dropdown
        llm_options = ['gpt', 'palm']
        # Create the dropdown and set the value to session_state['LLM']
        st.write("Which LLM is this prompt designed for? This will not restrict its use to a specific LLM, but some prompts will behave in different ways across models.")
|
@@ -982,10 +760,6 @@ The desired null value is also given. Populate the field with the null value of
        st.write("The following section tells the LLM how we want to structure the JSON dictionary. We do not recommend changing this section because it would likely result in unstable and inconsistent behavior.")
        st.session_state['json_formatting_instructions'] = st.text_area("Enter column instructions:", value=st.session_state['default_json_formatting_instructions'], height=350, disabled=True)

-
-
-
-
        st.write('---')
        col_left, col_right = st.columns([6,4])
        with col_left:
|
@@ -1004,12 +778,8 @@ The desired null value is also given. Populate the field with the null value of
                "taxonomy": ["Genus_species"]
            }

-            # Layout for adding a new column name
-            # col_text, col_textbtn = st.columns([8, 2])
-            # with col_text:
            new_column_name = st.text_input("Enter a new column name:")
-
-            # st.write('##')
            if st.button("Add New Column") and new_column_name:
                if new_column_name not in st.session_state['rules']['Dictionary']:
                    st.session_state['rules']['Dictionary'][new_column_name] = {"format": "", "null_value": "", "description": ""}
|
@@ -1032,9 +802,6 @@ The desired null value is also given. Populate the field with the null value of
            if 'selected_column' not in st.session_state:
                st.session_state['selected_column'] = column_name

-
-
-
            # Form for input fields
            with st.form(key='rule_form'):
                format_options = ["verbatim transcription", "spell check transcription", "boolean yes no", "boolean 1 0", "integer", "[list]", "yyyy-mm-dd"]
|
@@ -1063,43 +830,18 @@ The desired null value is also given. Populate the field with the null value of
                    # Force the form to reset by clearing the fields from the session state
                    st.session_state.pop('selected_column', None) # Clear the selected column to force reset

-                    # st.session_state['rules'][column_name] = current_rule
-                    # st.success(f"Column '{column_name}' added/updated in rules.")
-
-                    # # Reset current_rule to default values for the next input
-                    # current_rule["format"] = default_rule["format"]
-                    # current_rule["null_value"] = default_rule["null_value"]
-                    # current_rule["description"] = default_rule["description"]
-
-                    # # To ensure that the form fields are reset, we can clear them from the session state
-                    # for key in current_rule.keys():
-                    #     st.session_state[key] = default_rule[key]
-
            # Layout for removing an existing column
-            # del_col, del_colbtn = st.columns([8, 2])
-            # with del_col:
            delete_column_name = st.selectbox("Select a column to delete:", [""] + editable_columns, key='delete_column')
-            # with del_colbtn:
-            #     st.write('##')
            if st.button("Delete Column") and delete_column_name:
                del st.session_state['rules'][delete_column_name]
                st.success(f"Column '{delete_column_name}' removed from rules.")

-
-
-
        with col_right:
            # Display the current state of the JSON rules
            st.subheader('Formatted Columns')
            st.json(st.session_state['rules']['Dictionary'])

-            # st.subheader('All Prompt Info')
-            # st.json(st.session_state['prompt_info'])
-
-
        st.write('---')
-
-
        col_left_mapping, col_right_mapping = st.columns([6,4])
        with col_left_mapping:
            st.header("Mapping")
|
@@ -1157,7 +899,6 @@ The desired null value is also given. Populate the field with the null value of
            st.subheader('Formatted Column Maps')
            st.json(st.session_state['mapping'])

-
        st.write('---')
        st.header("Save and Download Custom Prompt")
        st.write('Once you click save, validation checks will verify the formatting and then a download button will appear so that you can ***save a local copy of your custom prompt.***')
|
@@ -1213,50 +954,37 @@ The desired null value is also given. Populate the field with the null value of
            }
            st.json(st.session_state['prompt_info'])

-def show_header_welcome():
-    st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
-    st.session_state.logo = Image.open(st.session_state.logo_path)
-    st.image(st.session_state.logo, width=250)
-
-def determine_n_images():
-    try:
-        # Check if 'dir_uploaded_images' key exists and it is not empty
-        if 'dir_uploaded_images' in st and st['dir_uploaded_images']:
-            dir_path = st['dir_uploaded_images'] # This would be the path to the directory
-            return len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))])
-        else:
-            return None
-    except:
-        return None
-
def content_header():
    col_run_1, col_run_2, col_run_3, col_run_4 = st.columns([2,2,2,2])

-
-    st.subheader("Overall Progress")
    col_run_info_1 = st.columns([1])[0]

-    st.write("")
-    st.header("Configuration Settings")
-
    with col_run_info_1:
        # Progress
        overall_progress_bar = st.progress(0)
        text_overall = st.empty() # Placeholder for current step name
        st.subheader('Transcription Progress')
        batch_progress_bar = st.progress(0)
        text_batch = st.empty() # Placeholder for current step name
        progress_report = ProgressReport(overall_progress_bar, batch_progress_bar, text_overall, text_batch)
        st.info("***Note:*** There is a known bug with tabs in Streamlit. If you update an input field it may take you back to the 'Project Settings' tab. Changes that you made are saved, it's just an annoying glitch. We are aware of this issue and will fix it as soon as we can.")
        st.write("If you use VoucherVision frequently, you can change the default values that are auto-populated in the form below. In a text editor or IDE, edit the first few rows in the file `../VoucherVision/vouchervision/VoucherVision_Config_Builder.py`")
-

    with col_run_1:
        show_header_welcome()
        st.subheader('Run VoucherVision')
-        N_STEPS = 6

-        if check_if_usable():
        if st.button(f"Start Processing{st.session_state['processing_add_on']}", type='primary'):

            # First, write the config file.
|
@@ -1265,7 +993,7 @@ def content_header():
            path_custom_prompts = os.path.join(st.session_state.dir_home,'custom_prompts',st.session_state.config['leafmachine']['project']['prompt_version'])

            # Define number of overall steps
-            progress_report.set_n_overall(
            progress_report.update_overall(f"Starting VoucherVision...")

            # Call the machine function.
@@ -1287,12 +1015,6 @@ def content_header():
|
|
1287 |
|
1288 |
if st.session_state['zip_filepath']:
|
1289 |
create_download_button(st.session_state['zip_filepath'])
|
1290 |
-
|
1291 |
-
|
1292 |
-
else:
|
1293 |
-
st.button("Start Processing", type='primary', disabled=True)
|
1294 |
-
# st.error(":heavy_exclamation_mark: Required API keys not set. Please visit the 'API Keys' tab and set the Google Vision OCR API key and at least one LLM key.")
|
1295 |
-
st.error(":heavy_exclamation_mark: Required API keys not set. Please set the API keys as 'Secrets' for your Hugging Face Space. Visit the 'Settings' tab at the top of the page.")
|
1296 |
st.button("Refresh", on_click=refresh)
|
1297 |
|
1298 |
with col_run_2:
|
@@ -1309,51 +1031,21 @@ def content_header():
        st.write('8. Start processing --- Wait for VoucherVision to finish.')
        st.write('9. Download results --- Click the "Download Results" button to save the VoucherVision output to your computer. ***Output files will disappear if you start a new run or restart the Space.***')
        st.write('10. Editing the LLM transcriptions --- Use the VoucherVisionEditor to revise and correct any mistakes or ommissions.')
-        # st.subheader('Run Tests', help="")
-        # st.write('We include a single image for testing. If you want to test all of the available prompts and LLMs on a different set of images, copy your images into `../VoucherVision/demo/demo_images`.')
-        # if st.button("Test GPT",disabled=True):
-        #     progress_report.set_n_overall(TestOptionsGPT.get_length())
-        #     test_results, JSON_results = run_demo_tests_GPT(progress_report)
-        #     with col_test:
-        #         display_test_results(test_results, JSON_results, 'gpt')
-        #     st.balloons()
-
-        # if st.button("Test PaLM2",disabled=True):
-        #     progress_report.set_n_overall(TestOptionsPalm.get_length())
-        #     test_results, JSON_results = run_demo_tests_Palm(progress_report)
-        #     with col_test:
-        #         display_test_results(test_results, JSON_results, 'palm')
-        #     st.balloons()

    with col_run_4:
        st.subheader('Available LLMs and APIs')
        show_available_APIs()
        st.info('Until the end of 2023, Azure OpenAI models will be available for anyone to use here. Then only PaLM 2 will be available. To use all services, duplicate this Space and provide your own API keys.')
-        # st.subheader('Check GPU')
-        # if st.button("GPU"):
-        #     success, info = test_GPU()
-
-        #     if success:
-        #         st.balloons()
-        #         for message in info:
-        #             st.success(message)
-        #     else:
-        #         for message in info:
-        #             st.error(message)

-def clear_image_gallery():
-    delete_directory(st.session_state['dir_uploaded_images'])
-    delete_directory(st.session_state['dir_uploaded_images_small'])
-    validate_dir(st.session_state['dir_uploaded_images'])
-    validate_dir(st.session_state['dir_uploaded_images_small'])

-
-
-
-
-
-
def content_tab_settings():
    col_project_1, col_project_2, col_project_3 = st.columns([2,2,2])

    st.write("---")
|
@@ -1374,7 +1066,6 @@ def content_tab_settings():
        st.subheader('Run name')
        st.session_state.config['leafmachine']['project']['run_name'] = st.text_input("Run name", st.session_state.config['leafmachine']['project'].get('run_name', ''),
                                                                                      label_visibility='collapsed')
-        # st.session_state.config['leafmachine']['project']['dir_output'] = st.text_input("Output directory", st.session_state.config['leafmachine']['project'].get('dir_output', ''))
        st.write("Run name will be the name of the final zipped folder.")

        ### LLM Version
|
@@ -1416,6 +1107,7 @@ def content_tab_settings():
        if st.button("Build Custom LLM Prompt",help="It may take a moment for the page to refresh."):
            st.session_state.proceed_to_build_llm_prompt = True
            st.rerun()

    ### Input Images Local
    with col_local_1:
        st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
@@ -1446,6 +1138,7 @@ def content_tab_settings():
|
|
1446 |
|
1447 |
st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)
|
1448 |
|
|
|
1449 |
with col_local_2:
|
1450 |
if st.session_state['input_list_small']:
|
1451 |
st.subheader('Image Gallery')
|
@@ -1456,11 +1149,7 @@ def content_tab_settings():
            # If there are less than 100 images, take them all
            images_to_display = st.session_state['input_list_small']
            st.image(images_to_display)
-
-            # st.image(st.session_state['input_list_small'])
-            # display_image_gallery()
-            # st.button("Clear Staged Images",on_click=delete_directory, args=[st.session_state['dir_uploaded_images']])
-
    with col_cropped_1:
        default_crops = st.session_state.config['leafmachine']['cropped_components'].get('save_cropped_annotations', ['leaf_whole'])
        st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. (Requires GPU)")
|
@@ -1484,301 +1173,85 @@ def content_tab_settings():
            image_ocr = Image.open(ocr)
            st.image(image_ocr, caption='OCR Overlay Images', output_format = "PNG")

-def content_tab_component():
-    st.header('Archival Components')
-    ACD_version = st.selectbox("Archival Component Detector (ACD) Version", ["Version 2.1", "Version 2.2"])
-
-    ACD_confidence_default = int(st.session_state.config['leafmachine']['archival_component_detector']['minimum_confidence_threshold'] * 100)
-    ACD_confidence = st.number_input("ACD Confidence Threshold (%)", min_value=0, max_value=100,value=ACD_confidence_default)
-    st.session_state.config['leafmachine']['archival_component_detector']['minimum_confidence_threshold'] = float(ACD_confidence/100)
-
-    st.session_state.config['leafmachine']['archival_component_detector']['do_save_prediction_overlay_images'] = st.checkbox("Save Archival Prediction Overlay Images", st.session_state.config['leafmachine']['archival_component_detector'].get('do_save_prediction_overlay_images', True))
-
-    st.session_state.config['leafmachine']['archival_component_detector']['ignore_objects_for_overlay'] = st.multiselect("Hide Archival Components in Prediction Overlay Images",
-                                                                                                                         ['ruler', 'barcode','label', 'colorcard','map','envelope','photo','attached_item','weights',],
-                                                                                                                         default=[])
-
-    # Depending on the selected version, set the configuration
-    if ACD_version == "Version 2.1":
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_type'] = 'Archival_Detector'
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_version'] = 'PREP_final'
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_iteration'] = 'PREP_final'
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_weights'] = 'best.pt'
-    elif ACD_version == "Version 2.2": #TODO update this to version 2.2
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_type'] = 'Archival_Detector'
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_version'] = 'PREP_final'
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_iteration'] = 'PREP_final'
-        st.session_state.config['leafmachine']['archival_component_detector']['detector_weights'] = 'best.pt'
-
-
-def content_tab_processing():
-    st.header('Processing Options')
-    col_processing_1, col_processing_2 = st.columns([2,2,])
-    with col_processing_1:
-        st.subheader('Compute Options')
-        st.session_state.config['leafmachine']['project']['num_workers'] = st.number_input("Number of CPU workers", value=st.session_state.config['leafmachine']['project'].get('num_workers', 1), disabled=True)
-        st.session_state.config['leafmachine']['project']['batch_size'] = st.number_input("Batch size", value=st.session_state.config['leafmachine']['project'].get('batch_size', 500), help='Sets the batch size for the LeafMachine2 cropping. If computer RAM is filled, lower this value to ~100.')
-    with col_processing_2:
-        st.subheader('Misc')
-        st.session_state.config['leafmachine']['project']['prefix_removal'] = st.text_input("Remove prefix from catalog number", st.session_state.config['leafmachine']['project'].get('prefix_removal', ''))
-        st.session_state.config['leafmachine']['project']['suffix_removal'] = st.text_input("Remove suffix from catalog number", st.session_state.config['leafmachine']['project'].get('suffix_removal', ''))
-        st.session_state.config['leafmachine']['project']['catalog_numerical_only'] = st.checkbox("Require 'Catalog Number' to be numerical only", st.session_state.config['leafmachine']['project'].get('catalog_numerical_only', True))
-
-    ### Logging and Image Validation - col_v1
-    st.header('Logging and Image Validation')
-    col_v1, col_v2 = st.columns(2)
-    with col_v1:
-        st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'] = st.checkbox("Check for illegal filenames", st.session_state.config['leafmachine']['do'].get('check_for_illegal_filenames', True))
-        st.session_state.config['leafmachine']['do']['check_for_corrupt_images_make_vertical'] = st.checkbox("Check for corrupt images", st.session_state.config['leafmachine']['do'].get('check_for_corrupt_images_make_vertical', True))
-
-        st.session_state.config['leafmachine']['print']['verbose'] = st.checkbox("Print verbose", st.session_state.config['leafmachine']['print'].get('verbose', True))
-        st.session_state.config['leafmachine']['print']['optional_warnings'] = st.checkbox("Show optional warnings", st.session_state.config['leafmachine']['print'].get('optional_warnings', True))
-
-    with col_v2:
-        log_level = st.session_state.config['leafmachine']['logging'].get('log_level', None)
-        log_level_display = log_level if log_level is not None else 'default'
-        selected_log_level = st.selectbox("Logging Level", ['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'], index=['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'].index(log_level_display))
-
-        if selected_log_level == 'default':
-            st.session_state.config['leafmachine']['logging']['log_level'] = None
-        else:
-            st.session_state.config['leafmachine']['logging']['log_level'] = selected_log_level
-
-def content_tab_domain():
-    st.header('Embeddings Database')
-    col_emb_1, col_emb_2 = st.columns([4,2])
-    with col_emb_1:
-        st.markdown(
-            """
-            VoucherVision includes the option of using domain knowledge inside of the dynamically generated prompts. The OCR text is queried against a database of existing label transcriptions. The most similar existing transcriptions act as an example of what the LLM should emulate and are shown to the LLM as JSON objects. VoucherVision uses cosine similarity search to return the most similar existing transcription.
-            - Note: Using domain knowledge may increase the chance that foreign text is included in the final transcription
-            - Disabling this feature will show the LLM multiple examples of an empty JSON skeleton structure instead
-            - Enabling this option requires a GPU with at least 8GB of VRAM
-            - The domain knowledge files can be located in the directory "../VoucherVision/domain_knowledge". On first run the embeddings database must be created, which takes time. If the database creation runs each time you use VoucherVision, then something is wrong.
-            """
-        )
-
-        st.write(f"Domain Knowledge is only available for the following prompts:")
-        for available_prompts in PROMPTS_THAT_NEED_DOMAIN_KNOWLEDGE:
-            st.markdown(f"- {available_prompts}")
-
-    if st.session_state.config['leafmachine']['project']['prompt_version'] in PROMPTS_THAT_NEED_DOMAIN_KNOWLEDGE:
-        st.session_state.config['leafmachine']['project']['use_domain_knowledge'] = st.checkbox("Use domain knowledge", True, disabled=True)
-    else:
-        st.session_state.config['leafmachine']['project']['use_domain_knowledge'] = st.checkbox("Use domain knowledge", False, disabled=True)
-
-    st.write("")
-    if st.session_state.config['leafmachine']['project']['use_domain_knowledge']:
-        st.session_state.config['leafmachine']['project']['embeddings_database_name'] = st.text_input("Embeddings database name (only use underscores)", st.session_state.config['leafmachine']['project'].get('embeddings_database_name', ''))
-        st.session_state.config['leafmachine']['project']['build_new_embeddings_database'] = st.checkbox("Build *new* embeddings database", st.session_state.config['leafmachine']['project'].get('build_new_embeddings_database', False))
-        st.session_state.config['leafmachine']['project']['path_to_domain_knowledge_xlsx'] = st.text_input("Path to domain knowledge CSV file (will be used to create new embeddings database)", st.session_state.config['leafmachine']['project'].get('path_to_domain_knowledge_xlsx', ''))
-    else:
-        st.session_state.config['leafmachine']['project']['embeddings_database_name'] = st.text_input("Embeddings database name (only use underscores)", st.session_state.config['leafmachine']['project'].get('embeddings_database_name', ''), disabled=True)
-        st.session_state.config['leafmachine']['project']['build_new_embeddings_database'] = st.checkbox("Build *new* embeddings database", st.session_state.config['leafmachine']['project'].get('build_new_embeddings_database', False), disabled=True)
-        st.session_state.config['leafmachine']['project']['path_to_domain_knowledge_xlsx'] = st.text_input("Path to domain knowledge CSV file (will be used to create new embeddings database)", st.session_state.config['leafmachine']['project'].get('path_to_domain_knowledge_xlsx', ''), disabled=True)
-
-def render_expense_report_summary():
-    expense_summary = st.session_state.expense_summary
-    expense_report = st.session_state.expense_report
-    st.header('Expense Report Summary')
-
-    if expense_summary:
-        st.metric(label="Total Cost", value=f"${round(expense_summary['total_cost_sum'], 4):,}")
-        col1, col2 = st.columns(2)
-
-        # Run count and total costs
-        with col1:
-            st.metric(label="Run Count", value=expense_summary['run_count'])
-            st.metric(label="Tokens In", value=f"{expense_summary['tokens_in_sum']:,}")
-
-        # Token information
-        with col2:
-            st.metric(label="Total Images", value=expense_summary['n_images_sum'])
-            st.metric(label="Tokens Out", value=f"{expense_summary['tokens_out_sum']:,}")
-
-
-        # Calculate cost proportion per image for each API version
-        st.subheader('Average Cost per Image by API Version')
-        cost_labels = []
-        cost_values = []
-        total_images = 0
-        cost_per_image_dict = {}
-        # Iterate through the expense report to accumulate costs and image counts
-        for index, row in expense_report.iterrows():
-            api_version = row['api_version']
-            total_cost = row['total_cost']
-            n_images = row['n_images']
-            total_images += n_images # Keep track of total images processed
-            if api_version not in cost_per_image_dict:
-                cost_per_image_dict[api_version] = {'total_cost': 0, 'n_images': 0}
-            cost_per_image_dict[api_version]['total_cost'] += total_cost
-            cost_per_image_dict[api_version]['n_images'] += n_images
-
-        api_versions = list(cost_per_image_dict.keys())
-        colors = [COLORS_EXPENSE_REPORT[version] if version in COLORS_EXPENSE_REPORT else '#DDDDDD' for version in api_versions]
-
-        # Calculate the cost per image for each API version
-        for version, cost_data in cost_per_image_dict.items():
-            total_cost = cost_data['total_cost']
-            n_images = cost_data['n_images']
-            # Calculate the cost per image for this version
-            cost_per_image = total_cost / n_images if n_images > 0 else 0
-            cost_labels.append(version)
-            cost_values.append(cost_per_image)
-        # Generate the pie chart
-        cost_pie_chart = go.Figure(data=[go.Pie(labels=cost_labels, values=cost_values, hole=.3)])
-        # Update traces for custom text in hoverinfo, displaying cost with a dollar sign and two decimal places
-        cost_pie_chart.update_traces(
-            marker=dict(colors=colors),
-            text=[f"${value:.2f}" for value in cost_values], # Formats the cost as a string with a dollar sign and two decimals
-            textinfo='percent+label',
-            hoverinfo='label+percent+text' # Adds custom text (formatted cost) to the hover information
-        )
-        st.plotly_chart(cost_pie_chart, use_container_width=True)
-
-        st.subheader('Proportion of Total Cost by API Version')
-        cost_labels = []
-        cost_proportions = []
-        total_cost_by_version = {}
-        # Sum the total cost for each API version
-        for index, row in expense_report.iterrows():
-            api_version = row['api_version']
-            total_cost = row['total_cost']
-            if api_version not in total_cost_by_version:
-                total_cost_by_version[api_version] = 0
-            total_cost_by_version[api_version] += total_cost
-        # Calculate the combined total cost for all versions
-        combined_total_cost = sum(total_cost_by_version.values())
-        # Calculate the proportion of total cost for each API version
-        for version, total_cost in total_cost_by_version.items():
-            proportion = (total_cost / combined_total_cost) * 100 if combined_total_cost > 0 else 0
-            cost_labels.append(version)
-            cost_proportions.append(proportion)
-        # Generate the pie chart
-        cost_pie_chart = go.Figure(data=[go.Pie(labels=cost_labels, values=cost_proportions, hole=.3)])
-        # Update traces for custom text in hoverinfo
-        cost_pie_chart.update_traces(
-            marker=dict(colors=colors),
-            text=[f"${cost:.2f}" for cost in total_cost_by_version.values()], # This will format the cost to 2 decimal places
-            textinfo='percent+label',
-            hoverinfo='label+percent+text' # This tells Plotly to show the label, percent, and custom text (cost) on hover
-        )
-        st.plotly_chart(cost_pie_chart, use_container_width=True)
-
-        # API version usage percentages pie chart
-        st.subheader('Runs by API Version')
-        api_versions = list(expense_summary['api_version_percentages'].keys())
-        percentages = [expense_summary['api_version_percentages'][version] for version in api_versions]
-        pie_chart = go.Figure(data=[go.Pie(labels=api_versions, values=percentages, hole=.3)])
-        pie_chart.update_layout(margin=dict(t=0, b=0, l=0, r=0))
-        pie_chart.update_traces(marker=dict(colors=colors),)
-        st.plotly_chart(pie_chart, use_container_width=True)
-
-    else:
-        st.error('No expense report data available.')
-
-def sidebar_content():
-    if not os.path.exists(os.path.join(st.session_state.dir_home,'expense_report')):
-        validate_dir(os.path.join(st.session_state.dir_home,'expense_report'))
-    expense_report_path = os.path.join(st.session_state.dir_home, 'expense_report', 'expense_report.csv')

-    if os.path.exists(expense_report_path):
-        # File exists, proceed with summarization
-        st.session_state.expense_summary, st.session_state.expense_report = summarize_expense_report(expense_report_path)
-        render_expense_report_summary()
-    else:
-        # File does not exist, handle this case appropriately
-        # For example, you could set the session state variables to None or an empty value
-        st.session_state.expense_summary, st.session_state.expense_report = None, None
-        st.header('Expense Report Summary')
-        st.write('Available after first run...')
-        # st.write('Google PaLM 2 is not tracked since it is currently free.')

def main():
    with st.sidebar:
        sidebar_content()
    # Main App
    content_header()

-    # tab_settings, tab_prompt, tab_domain, tab_component, tab_processing, tab_private, tab_delete = st.tabs(["Project Settings", "Prompt Builder", "Domain Knowledge","Component Detector", "Processing Options", "API Keys", "Space-Saver"])
-    # tab_settings, tab_prompt, tab_domain, tab_component, tab_processing, tab_delete = st.tabs(["Project Settings", "Prompt Builder", "Domain Knowledge","Component Detector", "Processing Options", "Space-Saver"])
    tab_settings = st.container()

    with tab_settings:
        content_tab_settings()

-    # with tab_prompt:
-    #     if st.button("Build Custom LLM Prompt"):
-    #         st.session_state.proceed_to_build_llm_prompt = True
-    #         st.rerun()
-    #     st.write('When opening the Prompt Builder, it take a moment for the page to refresh.')
-
-    # with tab_component:
-    #     # content_tab_component()
-    #     st.markdown("Not available in Hugging Face Spaces implementation. Please use the full [GitHub version](https://github.com/Gene-Weaver/VoucherVision) if you require these features (most use cases do not).")
-
-    # with tab_domain:
-    #     # content_tab_domain()
-    #     st.markdown("Not available in Hugging Face Spaces implementation. Please use the full [GitHub version](https://github.com/Gene-Weaver/VoucherVision) if you require these features (most use cases do not).")

-    # with tab_processing:
-    #     # content_tab_processing()
-    #     st.markdown("Not available in Hugging Face Spaces implementation. Please use the full [GitHub version](https://github.com/Gene-Weaver/VoucherVision) if you require these features (most use cases do not).")

-
-
-
-
-
-    # with tab_delete:
-    #     create_space_saver()
-    #     st.markdown("Not available in Hugging Face Spaces implementation. Please use the full [GitHub version](https://github.com/Gene-Weaver/VoucherVision) if you require these features (most use cases do not).")


-st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision')

-
if 'config' not in st.session_state:
    st.session_state.config, st.session_state.dir_home = build_VV_config()
    setup_streamlit_config(st.session_state.dir_home)

if 'proceed_to_main' not in st.session_state:
-    st.session_state.proceed_to_main = True

if 'proceed_to_build_llm_prompt' not in st.session_state:
-    st.session_state.proceed_to_build_llm_prompt = False
if 'proceed_to_private' not in st.session_state:
-    st.session_state.proceed_to_private = False

if 'dir_uploaded_images' not in st.session_state:
    st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
    validate_dir(os.path.join(st.session_state.dir_home,'uploads'))
|
|
|
1756 |
if 'dir_uploaded_images_small' not in st.session_state:
|
1757 |
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
1758 |
validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))
|
1759 |
|
1760 |
-
|
1761 |
-
# Initialize session_state variables if they don't exist
|
1762 |
if 'prompt_info' not in st.session_state:
|
1763 |
st.session_state['prompt_info'] = {}
|
|
|
1764 |
if 'rules' not in st.session_state:
|
1765 |
st.session_state['rules'] = {}
|
|
|
1766 |
if 'zip_filepath' not in st.session_state:
|
1767 |
st.session_state['zip_filepath'] = None
|
|
|
1768 |
if 'input_list' not in st.session_state:
|
1769 |
st.session_state['input_list'] = []
|
|
|
1770 |
if 'input_list_small' not in st.session_state:
|
1771 |
st.session_state['input_list_small'] = []
|
|
|
1772 |
if 'selected_yaml_file' not in st.session_state:
|
1773 |
st.session_state['selected_yaml_file'] = None
|
|
|
1774 |
if 'new_prompt_yaml_filename' not in st.session_state:
|
1775 |
st.session_state['new_prompt_yaml_filename'] = None
|
|
|
1776 |
if 'show_prompt_name_e' not in st.session_state:
|
1777 |
st.session_state['show_prompt_name_e'] = None
|
|
|
1778 |
if 'show_prompt_name_w' not in st.session_state:
|
1779 |
st.session_state['show_prompt_name_w'] = None
|
|
|
1780 |
if 'user_clicked_load_prompt_yaml' not in st.session_state:
|
1781 |
st.session_state['user_clicked_load_prompt_yaml'] = None
|
|
|
1782 |
if 'processing_add_on' not in st.session_state:
|
1783 |
st.session_state['processing_add_on'] = ' 1 Image'
|
1784 |
|
@@ -1801,12 +1274,12 @@ for api_name, key_state in st.session_state['api_name_to_key_state'].items():
|
|
1801 |
if key_state not in st.session_state:
|
1802 |
st.session_state[key_state] = False
|
1803 |
|
1804 |
-
|
1805 |
-
|
|
|
|
|
|
|
1806 |
if st.session_state.proceed_to_build_llm_prompt:
|
1807 |
build_LLM_prompt_config()
|
1808 |
-
elif st.session_state.proceed_to_private:
|
1809 |
-
# create_private_file()
|
1810 |
-
pass
|
1811 |
elif st.session_state.proceed_to_main:
|
1812 |
main()
|
|
|
import streamlit as st
+import yaml, os, json, random, time, shutil
import plotly.graph_objs as go
from itertools import chain
from PIL import Image
from io import BytesIO
from streamlit_extras.let_it_rain import rain
+
from vouchervision.LeafMachine2_Config_Builder import write_config_file
+from vouchervision.VoucherVision_Config_Builder import build_VV_config , TestOptionsGPT, TestOptionsPalm, check_if_usable
+from vouchervision.vouchervision_main import voucher_vision
+from vouchervision.general_utils import summarize_expense_report, validate_dir
+from vouchervision.utils import upload_to_drive, image_to_base64, setup_streamlit_config, save_uploaded_file, check_prompt_yaml_filename
+
+

+########################################################################################################
+### Constants ####
+########################################################################################################
PROMPTS_THAT_NEED_DOMAIN_KNOWLEDGE = ["Version 1","Version 1 PaLM 2"]
# LLM_VERSIONS = ["GPT 4", "GPT 3.5", "Azure GPT 4", "Azure GPT 3.5", "PaLM 2"]
COLORS_EXPENSE_REPORT = {
[...]
}
MAX_GALLERY_IMAGES = 50
GALLERY_IMAGE_SIZE = 128
+N_OVERALL_STEPS = 6

+
+
+########################################################################################################
+### Progress bar ####
+########################################################################################################
class ProgressReport:
    def __init__(self, overall_bar, batch_bar, text_overall, text_batch):
        self.overall_bar = overall_bar
[...]
        self.overall_bar.progress(0)
        self.text_overall.text(step_name)

    def get_n_images(self):
        return self.n_images
    def get_n_overall(self):
        return self.total_overall_steps



+########################################################################################################
+### Streamlit helper functions ####
+########################################################################################################
def display_scrollable_results(JSON_results, test_results, OPT2, OPT3):
    """
    Display the results from JSON_results in a scrollable container.
[...]
    st.markdown(css, unsafe_allow_html=True)
    st.markdown(results_html, unsafe_allow_html=True)

+
+
def display_test_results(test_results, JSON_results, llm_version):
    if llm_version == 'gpt':
        OPT1, OPT2, OPT3 = TestOptionsGPT.get_options()
[...]
    # Close the custom container
    st.write('</div>', unsafe_allow_html=True)

    for idx, (test_name, result) in enumerate(sorted(test_results.items())):
        _, ind_opt1, ind_opt2, ind_opt3 = test_name.split('__')
        opt2_readable = "Use LeafMachine2" if OPT2[int(ind_opt2.split('-')[1])] else "Don't use LeafMachine2"
[...]
    # proportional_rain("🥇", success_count, "💔", failure_count, font_size=72, falling_speed=5, animation_length="infinite")
    rain_emojis(test_results)

+
+
def add_emoji_delay():
    time.sleep(0.3)

+
+
def rain_emojis(test_results):
    # test_results = {
    #     'test1': True, # Test passed
[...]
        )
        add_emoji_delay()

+
+
def get_prompt_versions(LLM_version):
    yaml_files = [f for f in os.listdir(os.path.join(st.session_state.dir_home, 'custom_prompts')) if f.endswith('.yaml')]

[...]
    # Handle other cases or raise an error
    return (yaml_files, None)



def delete_directory(dir_path):
    try:
[...]
        st.error(f"Error: {dir_path} : {e.strerror}")



+# Function to load a YAML file and update session_state
+def load_prompt_yaml(filename):
+    st.session_state['user_clicked_load_prompt_yaml'] = filename
+    with open(filename, 'r') as file:
+        st.session_state['prompt_info'] = yaml.safe_load(file)
+        st.session_state['prompt_author'] = st.session_state['prompt_info'].get('prompt_author', st.session_state['default_prompt_author'])
+        st.session_state['prompt_author_institution'] = st.session_state['prompt_info'].get('prompt_author_institution', st.session_state['default_prompt_author_institution'])
+        st.session_state['prompt_description'] = st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description'])
+        st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'gpt')
+        st.session_state['instructions'] = st.session_state['prompt_info'].get('instructions', st.session_state['default_instructions'])
+        st.session_state['json_formatting_instructions'] = st.session_state['prompt_info'].get('json_formatting_instructions', st.session_state['default_json_formatting_instructions'] )
+        st.session_state['rules'] = st.session_state['prompt_info'].get('rules', {})
+        st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})

+        st.session_state['prompt_info'] = {
+            'prompt_author': st.session_state['prompt_author'],
+            'prompt_author_institution': st.session_state['prompt_author_institution'],
+            'prompt_description': st.session_state['prompt_description'],
+            'LLM': st.session_state['LLM'],
+            'instructions': st.session_state['instructions'],
+            'json_formatting_instructions': st.session_state['json_formatting_instructions'],
+            'rules': st.session_state['rules'],
+            'mapping': st.session_state['mapping'],
+        }

+        # Placeholder:
+        st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))



+def save_prompt_yaml(filename, col_right_save):
+    yaml_content = {
+        'prompt_author': st.session_state['prompt_author'],
+        'prompt_author_institution': st.session_state['prompt_author_institution'],
+        'prompt_description': st.session_state['prompt_description'],
+        'LLM': st.session_state['LLM'],
+        'instructions': st.session_state['instructions'],
+        'json_formatting_instructions': st.session_state['json_formatting_instructions'],
+        'rules': st.session_state['rules'],
+        'mapping': st.session_state['mapping'],
+    }

+    dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
+    filepath = os.path.join(dir_prompt, f"{filename}.yaml")
+
+    with open(filepath, 'w') as file:
+        yaml.safe_dump(dict(yaml_content), file, sort_keys=False)
+
+    st.success(f"Prompt saved as '{filename}.yaml'.")
+
+    upload_to_drive(filepath, filename)
+
+    with col_right_save:
+        create_download_button_yaml(filepath, filename)

+
+
+def check_unique_mapping_assignments():
+    if len(st.session_state['assigned_columns']) != len(set(st.session_state['assigned_columns'])):
+        st.error("Each column name must be assigned to only one category.")
+        return False
+    else:
+        st.success("Mapping confirmed.")
+        return True
+
+
+
+def create_download_button(zip_filepath):
+    with open(zip_filepath, 'rb') as f:
+        bytes_io = BytesIO(f.read())
+    st.download_button(
+        label=f"Download Results for{st.session_state['processing_add_on']}",type='primary',
+        data=bytes_io,
+        file_name=os.path.basename(zip_filepath),
+        mime='application/zip'
+    )
+
+
+
+def btn_load_prompt(selected_yaml_file, dir_prompt):
+    if selected_yaml_file:
+        yaml_file_path = os.path.join(dir_prompt, selected_yaml_file)
+        load_prompt_yaml(yaml_file_path)
+    elif not selected_yaml_file:
+        # Directly assigning default values since no file is selected
+        st.session_state['prompt_info'] = {}
+        st.session_state['prompt_author'] = st.session_state['default_prompt_author']
+        st.session_state['prompt_author_institution'] = st.session_state['default_prompt_author_institution']
+        st.session_state['prompt_description'] = st.session_state['default_prompt_description']
+        st.session_state['instructions'] = st.session_state['default_instructions']
+        st.session_state['json_formatting_instructions'] = st.session_state['default_json_formatting_instructions']
+        st.session_state['rules'] = {}
+        st.session_state['LLM'] = 'gpt'
+
+        st.session_state['assigned_columns'] = []
+
+        st.session_state['prompt_info'] = {
+            'prompt_author': st.session_state['prompt_author'],
+            'prompt_author_institution': st.session_state['prompt_author_institution'],
+            'prompt_description': st.session_state['prompt_description'],
+            'LLM': st.session_state['LLM'],
+            'instructions': st.session_state['instructions'],
+            'json_formatting_instructions': st.session_state['json_formatting_instructions'],
+            'rules': st.session_state['rules'],
+            'mapping': st.session_state['mapping'],
+        }
+
+
+
+def refresh():
+    st.write('')
+
+
+
+def upload_local_prompt_to_server(dir_prompt):
+    uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
+    if uploaded_file is not None:
+        # Check the file extension
+        file_name = uploaded_file.name
+        if file_name.endswith('.yaml'):
+            file_path = os.path.join(dir_prompt, file_name)

+            # Save the file
+            with open(file_path, 'wb') as f:
+                f.write(uploaded_file.getbuffer())
+            st.success(f"Saved file {file_name} in {dir_prompt}")
+        else:
+            st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")
+
+
+
+def create_download_button_yaml(file_path, selected_yaml_file):
+    file_label = f"Download {selected_yaml_file}"
+    with open(file_path, 'rb') as f:
+        st.download_button(
+            label=file_label,
+            data=f,
+            file_name=os.path.basename(file_path),
+            mime='application/x-yaml'
+        )
+
+
+
+def clear_image_gallery():
+    delete_directory(st.session_state['dir_uploaded_images'])
+    delete_directory(st.session_state['dir_uploaded_images_small'])
+    validate_dir(st.session_state['dir_uploaded_images'])
+    validate_dir(st.session_state['dir_uploaded_images_small'])
+
+
+
+def use_test_image():
+    st.info(f"Processing images from {os.path.join(st.session_state.dir_home,'demo','demo_images')}")
+    st.session_state.config['leafmachine']['project']['dir_images_local'] = os.path.join(st.session_state.dir_home,'demo','demo_images')
+    n_images = len([f for f in os.listdir(st.session_state.config['leafmachine']['project']['dir_images_local']) if os.path.isfile(os.path.join(st.session_state.config['leafmachine']['project']['dir_images_local'], f))])
+    st.session_state['processing_add_on'] = f" {n_images} Images"
+
+
+
+
+########################################################################################################
+### Streamlit sections ####
+########################################################################################################
+def create_space_saver():
+    st.subheader("Space Saving Options")
+    col_ss_1, col_ss_2 = st.columns([2,2])
+    with col_ss_1:
+        st.write("Several folders are created and populated with data during the VoucherVision transcription process.")
+        st.write("Below are several options that will allow you to automatically delete temporary files that you may not need for everyday operations.")
+        st.write("VoucherVision creates the following folders. Folders marked with a :star: are required if you want to use VoucherVisionEditor for quality control.")
+        st.write("`../[Run Name]/Archival_Components`")
+        st.write("`../[Run Name]/Config_File`")
+        st.write("`../[Run Name]/Cropped_Images` :star:")
+        st.write("`../[Run Name]/Logs`")
+        st.write("`../[Run Name]/Original_Images` :star:")
+        st.write("`../[Run Name]/Transcription` :star:")
+    with col_ss_2:
+        st.session_state.config['leafmachine']['project']['delete_temps_keep_VVE'] = st.checkbox("Delete Temporary Files (KEEP files required for VoucherVisionEditor)", st.session_state.config['leafmachine']['project'].get('delete_temps_keep_VVE', False))
+        st.session_state.config['leafmachine']['project']['delete_all_temps'] = st.checkbox("Keep only the final transcription file", st.session_state.config['leafmachine']['project'].get('delete_all_temps', False),help="*WARNING:* This limits your ability to do quality assurance. This will delete all folders created by VoucherVision, leaving only the `transcription.xlsx` file.")
+
+

def show_available_APIs():
    st.session_state['has_key_openai'] = (os.getenv('OPENAI_API_KEY') is not None) and (os.getenv('OPENAI_API_KEY') != '')
[...]
    # Apply the CSS
    st.markdown(css, unsafe_allow_html=True)

+def show_header_welcome():
+    st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
+    st.session_state.logo = Image.open(st.session_state.logo_path)
+    st.image(st.session_state.logo, width=250)



+########################################################################################################
+### Sidebar for Expense Report ####
+########################################################################################################
+def render_expense_report_summary():
+    cost_labels = []
+    cost_values = []
+    total_images = 0
+    cost_per_image_dict = {}
+    st.header('Expense Report Summary')

+    if st.session_state.expense_summary:
+        st.metric(label="Total Cost", value=f"${round(st.session_state.expense_summary['total_cost_sum'], 4):,}")
+        col1, col2 = st.columns(2)

+        # Run count and total costs
+        with col1:
+            st.metric(label="Run Count", value=st.session_state.expense_summary['run_count'])
+            st.metric(label="Tokens In", value=f"{st.session_state.expense_summary['tokens_in_sum']:,}")

+        # Token information
+        with col2:
+            st.metric(label="Total Images", value=st.session_state.expense_summary['n_images_sum'])
+            st.metric(label="Tokens Out", value=f"{st.session_state.expense_summary['tokens_out_sum']:,}")

+        # Calculate cost proportion per image for each API version
+        st.subheader('Average Cost per Image by API Version')
+
+        # Iterate through the expense report to accumulate costs and image counts
+        for index, row in st.session_state.expense_report.iterrows():
+            api_version = row['api_version']
+            total_cost = row['total_cost']
+            n_images = row['n_images']
+            total_images += n_images # Keep track of total images processed
+            if api_version not in cost_per_image_dict:
+                cost_per_image_dict[api_version] = {'total_cost': 0, 'n_images': 0}
+            cost_per_image_dict[api_version]['total_cost'] += total_cost
+            cost_per_image_dict[api_version]['n_images'] += n_images

+        api_versions = list(cost_per_image_dict.keys())
+        colors = [COLORS_EXPENSE_REPORT[version] if version in COLORS_EXPENSE_REPORT else '#DDDDDD' for version in api_versions]
+
+        # Calculate the cost per image for each API version
+        for version, cost_data in cost_per_image_dict.items():
+            total_cost = cost_data['total_cost']
+            n_images = cost_data['n_images']

+            # Calculate the cost per image for this version
+            cost_per_image = total_cost / n_images if n_images > 0 else 0
+            cost_labels.append(version)
+            cost_values.append(cost_per_image)

+        # Generate the pie chart
+        cost_pie_chart = go.Figure(data=[go.Pie(labels=cost_labels, values=cost_values, hole=.3)])

+        # Update traces for custom text in hoverinfo, displaying cost with a dollar sign and two decimal places
+        cost_pie_chart.update_traces(
+            marker=dict(colors=colors),
+            text=[f"${value:.2f}" for value in cost_values],
+            textinfo='percent+label',
+            hoverinfo='label+percent+text'
        )
+        st.plotly_chart(cost_pie_chart, use_container_width=True)

+        st.subheader('Proportion of Total Cost by API Version')
+        cost_labels = []
+        cost_proportions = []
+        total_cost_by_version = {}

+        # Sum the total cost for each API version
+        for index, row in st.session_state.expense_report.iterrows():
+            api_version = row['api_version']
+            total_cost = row['total_cost']
+            if api_version not in total_cost_by_version:
+                total_cost_by_version[api_version] = 0
+            total_cost_by_version[api_version] += total_cost
+
+        # Calculate the combined total cost for all versions
+        combined_total_cost = sum(total_cost_by_version.values())

+        # Calculate the proportion of total cost for each API version
+        for version, total_cost in total_cost_by_version.items():
+            proportion = (total_cost / combined_total_cost) * 100 if combined_total_cost > 0 else 0
+            cost_labels.append(version)
+            cost_proportions.append(proportion)

+        # Generate the pie chart
+        cost_pie_chart = go.Figure(data=[go.Pie(labels=cost_labels, values=cost_proportions, hole=.3)])

+        # Update traces for custom text in hoverinfo
+        cost_pie_chart.update_traces(
+            marker=dict(colors=colors),
+            text=[f"${cost:.2f}" for cost in total_cost_by_version.values()],
+            textinfo='percent+label',
+            hoverinfo='label+percent+text'
+        )
+        st.plotly_chart(cost_pie_chart, use_container_width=True)

+        # API version usage percentages pie chart
+        st.subheader('Runs by API Version')
+        api_versions = list(st.session_state.expense_summary['api_version_percentages'].keys())
+        percentages = [st.session_state.expense_summary['api_version_percentages'][version] for version in api_versions]
+        pie_chart = go.Figure(data=[go.Pie(labels=api_versions, values=percentages, hole=.3)])
+        pie_chart.update_layout(margin=dict(t=0, b=0, l=0, r=0))
+        pie_chart.update_traces(marker=dict(colors=colors),)
+        st.plotly_chart(pie_chart, use_container_width=True)

+    else:
+        st.error('No expense report data available.')



+def sidebar_content():
+    if not os.path.exists(os.path.join(st.session_state.dir_home,'expense_report')):
+        validate_dir(os.path.join(st.session_state.dir_home,'expense_report'))
+    expense_report_path = os.path.join(st.session_state.dir_home, 'expense_report', 'expense_report.csv')
+
+    if os.path.exists(expense_report_path):
+        # File exists, proceed with summarization
+        st.session_state.expense_summary, st.session_state.expense_report = summarize_expense_report(expense_report_path)
+        render_expense_report_summary()
+    else:
+        st.session_state.expense_summary, st.session_state.expense_report = None, None
+        st.header('Expense Report Summary')
+        st.write('Available after first run...')
+
+
+
+########################################################################################################
+### Config Builder ####
+########################################################################################################
def build_LLM_prompt_config():
    st.session_state['assigned_columns'] = []
    st.session_state['default_prompt_author'] = 'unknown'
[...]

    # Start building the Streamlit app
    col_prompt_main_left, ___, col_prompt_main_right = st.columns([6,1,3])

    with col_prompt_main_left:
        st.title("Custom LLM Prompt Builder")
        st.subheader('About')
        st.write("This form allows you to craft a prompt for your specific task.")
[...]
        st.write("4. When you go back the main VoucherVision page you will now see your custom prompt available in the 'Prompt Version' dropdown menu.")
        st.write("5. Select your custom prompt. Note, your prompt will only be available for the LLM that you set when filling out the form below.")

        dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
        yaml_files = [f for f in os.listdir(dir_prompt) if f.endswith('.yaml')]
        col_upload_yaml, col_upload_yaml_2 = st.columns([4,4])
[...]
            # Create the download button
            st.write('##')
            create_download_button_yaml(download_file_path, st.session_state['selected_yaml_file'] )

        # Prompt Author Information
        st.header("Prompt Author Information")
[...]

        st.write("Please provide a description of your prompt and its intended task. Is it designed for a specific collection? Taxa? Database structure?")
        st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description']))

        # Input for new file name
+        st.write('---')
        st.header("Prompt Name")
        st.write('Provide a name for your custom prompt. It can only conatin letters, numbers, and underscores. No spaces, dashes, or special characters.')
        st.session_state['new_prompt_yaml_filename'] = st.text_input("Enter filename to save your prompt as a configuration YAML:", value=None, placeholder='my_prompt_name')

+        # Define the options for the LLM Model Type dropdown
        st.write('---')
        st.header("Set LLM Model Type")
        llm_options = ['gpt', 'palm']
        # Create the dropdown and set the value to session_state['LLM']
        st.write("Which LLM is this prompt designed for? This will not restrict its use to a specific LLM, but some prompts will behave in different ways across models.")
[...]
        st.write("The following section tells the LLM how we want to structure the JSON dictionary. We do not recommend changing this section because it would likely result in unstable and inconsistent behavior.")
        st.session_state['json_formatting_instructions'] = st.text_area("Enter column instructions:", value=st.session_state['default_json_formatting_instructions'], height=350, disabled=True)

        st.write('---')
        col_left, col_right = st.columns([6,4])
        with col_left:
[...]
                "taxonomy": ["Genus_species"]
            }

            new_column_name = st.text_input("Enter a new column name:")
+
            if st.button("Add New Column") and new_column_name:
                if new_column_name not in st.session_state['rules']['Dictionary']:
                    st.session_state['rules']['Dictionary'][new_column_name] = {"format": "", "null_value": "", "description": ""}
[...]
            if 'selected_column' not in st.session_state:
                st.session_state['selected_column'] = column_name

            # Form for input fields
            with st.form(key='rule_form'):
                format_options = ["verbatim transcription", "spell check transcription", "boolean yes no", "boolean 1 0", "integer", "[list]", "yyyy-mm-dd"]
[...]
                # Force the form to reset by clearing the fields from the session state
                st.session_state.pop('selected_column', None) # Clear the selected column to force reset

            # Layout for removing an existing column
            delete_column_name = st.selectbox("Select a column to delete:", [""] + editable_columns, key='delete_column')
            if st.button("Delete Column") and delete_column_name:
                del st.session_state['rules'][delete_column_name]
                st.success(f"Column '{delete_column_name}' removed from rules.")

        with col_right:
            # Display the current state of the JSON rules
            st.subheader('Formatted Columns')
            st.json(st.session_state['rules']['Dictionary'])

        st.write('---')
        col_left_mapping, col_right_mapping = st.columns([6,4])
        with col_left_mapping:
            st.header("Mapping")
[...]
            st.subheader('Formatted Column Maps')
            st.json(st.session_state['mapping'])

        st.write('---')
        st.header("Save and Download Custom Prompt")
        st.write('Once you click save, validation checks will verify the formatting and then a download button will appear so that you can ***save a local copy of your custom prompt.***')
[...]
        }
        st.json(st.session_state['prompt_info'])

def content_header():
+    # Header section, run, quick start, API report
    col_run_1, col_run_2, col_run_3, col_run_4 = st.columns([2,2,2,2])

+    # Progress bar
    col_run_info_1 = st.columns([1])[0]

    with col_run_info_1:
        # Progress
+        st.subheader("Overall Progress")
        overall_progress_bar = st.progress(0)
        text_overall = st.empty() # Placeholder for current step name
+
        st.subheader('Transcription Progress')
        batch_progress_bar = st.progress(0)
        text_batch = st.empty() # Placeholder for current step name
+
        progress_report = ProgressReport(overall_progress_bar, batch_progress_bar, text_overall, text_batch)
+
        st.info("***Note:*** There is a known bug with tabs in Streamlit. If you update an input field it may take you back to the 'Project Settings' tab. Changes that you made are saved, it's just an annoying glitch. We are aware of this issue and will fix it as soon as we can.")
        st.write("If you use VoucherVision frequently, you can change the default values that are auto-populated in the form below. In a text editor or IDE, edit the first few rows in the file `../VoucherVision/vouchervision/VoucherVision_Config_Builder.py`")

    with col_run_1:
        show_header_welcome()
        st.subheader('Run VoucherVision')

+        if not check_if_usable():
+            st.button("Start Processing", type='primary', disabled=True)
+            # st.error(":heavy_exclamation_mark: Required API keys not set. Please visit the 'API Keys' tab and set the Google Vision OCR API key and at least one LLM key.")
+            st.error(":heavy_exclamation_mark: Required API keys not set. Please set the API keys as 'Secrets' for your Hugging Face Space. Visit the 'Settings' tab at the top of the page.")
+        else:
            if st.button(f"Start Processing{st.session_state['processing_add_on']}", type='primary'):

                # First, write the config file.
[...]
                path_custom_prompts = os.path.join(st.session_state.dir_home,'custom_prompts',st.session_state.config['leafmachine']['project']['prompt_version'])

                # Define number of overall steps
+                progress_report.set_n_overall(N_OVERALL_STEPS)
                progress_report.update_overall(f"Starting VoucherVision...")

                # Call the machine function.
[...]

            if st.session_state['zip_filepath']:
                create_download_button(st.session_state['zip_filepath'])
        st.button("Refresh", on_click=refresh)

    with col_run_2:
[...]
        st.write('8. Start processing --- Wait for VoucherVision to finish.')
        st.write('9. Download results --- Click the "Download Results" button to save the VoucherVision output to your computer. ***Output files will disappear if you start a new run or restart the Space.***')
        st.write('10. Editing the LLM transcriptions --- Use the VoucherVisionEditor to revise and correct any mistakes or ommissions.')

    with col_run_4:
        st.subheader('Available LLMs and APIs')
        show_available_APIs()
        st.info('Until the end of 2023, Azure OpenAI models will be available for anyone to use here. Then only PaLM 2 will be available. To use all services, duplicate this Space and provide your own API keys.')


+
+
+########################################################################################################
+### Main Settings ####
+########################################################################################################
def content_tab_settings():
+    st.write("---")
+    st.header("Configuration Settings")
    col_project_1, col_project_2, col_project_3 = st.columns([2,2,2])

    st.write("---")
[...]
        st.subheader('Run name')
        st.session_state.config['leafmachine']['project']['run_name'] = st.text_input("Run name", st.session_state.config['leafmachine']['project'].get('run_name', ''),
                                                                                      label_visibility='collapsed')
        st.write("Run name will be the name of the final zipped folder.")

    ### LLM Version
[...]
        if st.button("Build Custom LLM Prompt",help="It may take a moment for the page to refresh."):
            st.session_state.proceed_to_build_llm_prompt = True
            st.rerun()
+
    ### Input Images Local
    with col_local_1:
        st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
[...]

        st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)

+    # Show uploaded images gallery (thumbnails only)
    with col_local_2:
        if st.session_state['input_list_small']:
            st.subheader('Image Gallery')
[...]
            # If there are less than 100 images, take them all
            images_to_display = st.session_state['input_list_small']
            st.image(images_to_display)
+
    with col_cropped_1:
        default_crops = st.session_state.config['leafmachine']['cropped_components'].get('save_cropped_annotations', ['leaf_whole'])
        st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. (Requires GPU)")
[...]
            image_ocr = Image.open(ocr)
            st.image(image_ocr, caption='OCR Overlay Images', output_format = "PNG")



+########################################################################################################
+### Main ####
+########################################################################################################
def main():
    with st.sidebar:
        sidebar_content()
    # Main App
    content_header()

    tab_settings = st.container()

    with tab_settings:
        content_tab_settings()



+########################################################################################################
+### STREAMLIT APP START ####
+########################################################################################################
+st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision')



+########################################################################################################
+### STREAMLIT INIT STATES ####
+########################################################################################################
if 'config' not in st.session_state:
    st.session_state.config, st.session_state.dir_home = build_VV_config()
    setup_streamlit_config(st.session_state.dir_home)

if 'proceed_to_main' not in st.session_state:
+    st.session_state.proceed_to_main = True

if 'proceed_to_build_llm_prompt' not in st.session_state:
+    st.session_state.proceed_to_build_llm_prompt = False
+
if 'proceed_to_private' not in st.session_state:
+    st.session_state.proceed_to_private = False

if 'dir_uploaded_images' not in st.session_state:
    st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
    validate_dir(os.path.join(st.session_state.dir_home,'uploads'))
+
if 'dir_uploaded_images_small' not in st.session_state:
    st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
    validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))

if 'prompt_info' not in st.session_state:
    st.session_state['prompt_info'] = {}
+
if 'rules' not in st.session_state:
    st.session_state['rules'] = {}
+
if 'zip_filepath' not in st.session_state:
    st.session_state['zip_filepath'] = None
+
if 'input_list' not in st.session_state:
    st.session_state['input_list'] = []
+
if 'input_list_small' not in st.session_state:
    st.session_state['input_list_small'] = []
+
if 'selected_yaml_file' not in st.session_state:
    st.session_state['selected_yaml_file'] = None
+
if 'new_prompt_yaml_filename' not in st.session_state:
    st.session_state['new_prompt_yaml_filename'] = None
+
if 'show_prompt_name_e' not in st.session_state:
    st.session_state['show_prompt_name_e'] = None
+
if 'show_prompt_name_w' not in st.session_state:
    st.session_state['show_prompt_name_w'] = None
+
if 'user_clicked_load_prompt_yaml' not in st.session_state:
    st.session_state['user_clicked_load_prompt_yaml'] = None
+
if 'processing_add_on' not in st.session_state:
    st.session_state['processing_add_on'] = ' 1 Image'

[...]
    if key_state not in st.session_state:
        st.session_state[key_state] = False

+
+
+########################################################################################################
+### STREAMLIT SESSION GUIDE ####
+########################################################################################################
if st.session_state.proceed_to_build_llm_prompt:
    build_LLM_prompt_config()
elif st.session_state.proceed_to_main:
    main()
vouchervision/utils.py
ADDED
@@ -0,0 +1,99 @@
+import os, json, re
+from googleapiclient.discovery import build
+from googleapiclient.http import MediaFileUpload
+from google.oauth2 import service_account
+import base64
+from PIL import Image
+from PIL import Image
+from io import BytesIO
+
+from vouchervision.general_utils import get_cfg_from_full_path
+
+
+def setup_streamlit_config(dir_home):
+    # Define the directory path and filename
+    dir_path = os.path.join(dir_home, ".streamlit")
+    file_path = os.path.join(dir_path, "config.toml")
+
+    # Check if directory exists, if not create it
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+
+    # Create or modify the file with the provided content
+    config_content = f"""
+[theme]
+base = "dark"
+primaryColor = "#00ff00"
+
+[server]
+enableStaticServing = false
+runOnSave = true
+port = 8524
+maxUploadSize = 5000
+"""
+
+    with open(file_path, "w") as f:
+        f.write(config_content.strip())
+
+
+
+def save_uploaded_file(directory, img_file, image=None):
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    # Assuming the uploaded file is an image
+    if image is None:
+        with Image.open(img_file) as image:
+            full_path = os.path.join(directory, img_file.name)
+            image.save(full_path, "JPEG")
+        # Return the full path of the saved image
+        return full_path
+    else:
+        full_path = os.path.join(directory, img_file.name)
+        image.save(full_path, "JPEG")
+        return full_path
+
+def image_to_base64(img):
+    buffered = BytesIO()
+    img.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode()
+
+def check_prompt_yaml_filename(fname):
+    # Check if the filename only contains letters, numbers, underscores, and dashes
+    pattern = r'^[\w-]+$'
+
+    # The \w matches any alphanumeric character and is equivalent to the character class [a-zA-Z0-9_].
+    # The hyphen - is literally matched.
+
+    if re.match(pattern, fname):
+        return True
+    else:
+        return False
+
+# Function to upload files to Google Drive
+def upload_to_drive(filepath, filename):
+    # Parse the service account info from the environment variable
+    creds_info = json.loads(os.environ.get('GDRIVE_API'))
+    if creds_info:
+        creds = service_account.Credentials.from_service_account_info(
+            creds_info, scopes=["https://www.googleapis.com/auth/drive"]
+        )
+        service = build('drive', 'v3', credentials=creds)
+
+        # Get the folder ID from the environment variable
+        folder_id = os.environ.get('GDRIVE')
+        # st.info(f"{folder_id}")
+
+        if folder_id:
+            file_metadata = {
+                'name': filename,
+                'parents': [folder_id]
+            }
+            # st.info(f"{file_metadata}")
+
+            media = MediaFileUpload(filepath, mimetype='application/x-yaml')
+
+            service.files().create(
+                body=file_metadata,
+                media_body=media,
+                fields='id'
+            ).execute()
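
A minimal usage sketch of the new vouchervision/utils.py helpers (illustrative only; it assumes the GDRIVE and GDRIVE_API secrets are configured for the Space and that dir_home points at the project root, which in app.py actually comes from build_VV_config()):

import os
from vouchervision.utils import setup_streamlit_config, check_prompt_yaml_filename, upload_to_drive

dir_home = os.path.dirname(__file__)              # assumed project root for this sketch
setup_streamlit_config(dir_home)                  # writes .streamlit/config.toml

fname = 'my_prompt_name'
if check_prompt_yaml_filename(fname):             # letters, numbers, underscores, and dashes only
    filepath = os.path.join(dir_home, 'custom_prompts', f'{fname}.yaml')
    upload_to_drive(filepath, fname)              # pushes the saved prompt YAML to the shared Drive folder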