Spaces:
Running
Running
phyloforfun
commited on
Commit
·
524a99c
1
Parent(s):
0560c52
Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
Browse files- app.py +463 -861
- install_dependencies.sh +85 -0
- pages/faqs.py +38 -0
- pages/prompt_builder.py +478 -0
- pages/report_bugs.py +19 -0
- requirements.txt +0 -0
- run_VoucherVision.py +3 -8
- vouchervision/LLM_crew_OpenAI.py +130 -0
- vouchervision/LLM_local_cpu_MistralAI.py +0 -2
- vouchervision/OCR_CRAFT.py +55 -0
- vouchervision/OCR_google_cloud_vision.py +272 -55
- vouchervision/OCR_llava.py +324 -0
- vouchervision/VoucherVision_Config_Builder.py +24 -6
- vouchervision/data_project.py +33 -6
- vouchervision/general_utils.py +11 -8
- vouchervision/llava_test.py +34 -0
- vouchervision/utils_LLM.py +80 -29
- vouchervision/utils_LLM_JSON_validation.py +1 -1
- vouchervision/utils_VoucherVision.py +22 -19
- vouchervision/vouchervision_main.py +12 -5
app.py
CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
|
|
2 |
import yaml, os, json, random, time, re, torch, random, warnings, shutil, sys
|
3 |
import seaborn as sns
|
4 |
import plotly.graph_objs as go
|
5 |
-
from itertools import chain
|
6 |
from PIL import Image
|
7 |
import pandas as pd
|
8 |
from io import BytesIO
|
@@ -15,30 +14,190 @@ from vouchervision.vouchervision_main import voucher_vision
|
|
15 |
from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, validate_dir
|
16 |
from vouchervision.model_maps import ModelMaps
|
17 |
from vouchervision.API_validation import APIvalidation
|
18 |
-
from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file,
|
19 |
-
|
|
|
20 |
|
21 |
|
22 |
#################################################################################################################################################
|
23 |
# Initializations ###############################################################################################################################
|
24 |
#################################################################################################################################################
|
25 |
-
|
26 |
-
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision')
|
27 |
|
28 |
# Parse the 'is_hf' argument and set it in session state
|
29 |
if 'is_hf' not in st.session_state:
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
|
33 |
-
########################################################################################################
|
34 |
-
### ADDED FOR HUGGING FACE ####
|
35 |
-
########################################################################################################
|
36 |
-
print(f"is_hf {st.session_state['is_hf']}")
|
37 |
# Default YAML file path
|
38 |
if 'config' not in st.session_state:
|
39 |
st.session_state.config, st.session_state.dir_home = build_VV_config(loaded_cfg=None)
|
40 |
setup_streamlit_config(st.session_state.dir_home)
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
if 'uploader_idk' not in st.session_state:
|
43 |
st.session_state['uploader_idk'] = 1
|
44 |
if 'input_list_small' not in st.session_state:
|
@@ -60,11 +219,12 @@ if 'dir_uploaded_images_small' not in st.session_state:
|
|
60 |
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
61 |
validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))
|
62 |
|
63 |
-
MAX_GALLERY_IMAGES = 20
|
64 |
-
GALLERY_IMAGE_SIZE = 96
|
65 |
|
66 |
|
67 |
|
|
|
|
|
|
|
68 |
def content_input_images(col_left, col_right):
|
69 |
st.write('---')
|
70 |
# col1, col2 = st.columns([2,8])
|
@@ -83,7 +243,7 @@ def content_input_images(col_left, col_right):
|
|
83 |
if st.session_state.is_hf:
|
84 |
st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
|
85 |
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
86 |
-
uploaded_files = st.file_uploader("Upload Images", type=['jpg', 'jpeg'], accept_multiple_files=True, key=st.session_state['uploader_idk'])
|
87 |
st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)
|
88 |
|
89 |
with col_right:
|
@@ -92,27 +252,37 @@ def content_input_images(col_left, col_right):
|
|
92 |
# Clear input image gallery and input list
|
93 |
clear_image_gallery()
|
94 |
|
95 |
-
# Process the new iamges
|
96 |
for uploaded_file in uploaded_files:
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
if st.session_state['input_list_small']:
|
118 |
if len(st.session_state['input_list_small']) > MAX_GALLERY_IMAGES:
|
@@ -150,7 +320,6 @@ def content_input_images(col_left, col_right):
|
|
150 |
st.session_state['dir_images_local_TEMP'] = st.session_state.config['leafmachine']['project']['dir_images_local']
|
151 |
print("rerun")
|
152 |
st.rerun()
|
153 |
-
|
154 |
|
155 |
def list_jpg_files(directory_path):
|
156 |
jpg_count = 0
|
@@ -243,39 +412,14 @@ def use_test_image():
|
|
243 |
st.session_state['input_list_small'].append(file_path_small)
|
244 |
|
245 |
|
246 |
-
def create_download_button_yaml(file_path, selected_yaml_file, key_val):
|
247 |
-
file_label = f"Download {selected_yaml_file}"
|
248 |
-
with open(file_path, 'rb') as f:
|
249 |
-
st.download_button(
|
250 |
-
label=file_label,
|
251 |
-
data=f,
|
252 |
-
file_name=os.path.basename(file_path),
|
253 |
-
mime='application/x-yaml',use_container_width=True,key=key_val,
|
254 |
-
)
|
255 |
-
|
256 |
-
|
257 |
-
def upload_local_prompt_to_server(dir_prompt):
|
258 |
-
uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
|
259 |
-
if uploaded_file is not None:
|
260 |
-
# Check the file extension
|
261 |
-
file_name = uploaded_file.name
|
262 |
-
if file_name.endswith('.yaml'):
|
263 |
-
file_path = os.path.join(dir_prompt, file_name)
|
264 |
-
|
265 |
-
# Save the file
|
266 |
-
with open(file_path, 'wb') as f:
|
267 |
-
f.write(uploaded_file.getbuffer())
|
268 |
-
st.success(f"Saved file {file_name} in {dir_prompt}")
|
269 |
-
else:
|
270 |
-
st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")
|
271 |
-
|
272 |
-
|
273 |
def refresh():
|
274 |
st.session_state['uploader_idk'] += 1
|
275 |
st.write('')
|
276 |
|
277 |
|
278 |
|
|
|
|
|
279 |
# def display_image_gallery():
|
280 |
# # Initialize the container
|
281 |
# con_image = st.empty()
|
@@ -516,10 +660,7 @@ class JSONReport:
|
|
516 |
|
517 |
|
518 |
|
519 |
-
|
520 |
-
dir_home = os.path.dirname(__file__)
|
521 |
-
path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
522 |
-
return os.path.exists(path_cfg_private)
|
523 |
|
524 |
|
525 |
|
@@ -971,534 +1112,14 @@ def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version
|
|
971 |
# st.session_state.private_file = does_private_file_exist()
|
972 |
|
973 |
# Function to load a YAML file and update session_state
|
974 |
-
|
975 |
-
st.session_state['user_clicked_load_prompt_yaml'] = filename
|
976 |
-
with open(filename, 'r') as file:
|
977 |
-
st.session_state['prompt_info'] = yaml.safe_load(file)
|
978 |
-
st.session_state['prompt_author'] = st.session_state['prompt_info'].get('prompt_author', st.session_state['default_prompt_author'])
|
979 |
-
st.session_state['prompt_author_institution'] = st.session_state['prompt_info'].get('prompt_author_institution', st.session_state['default_prompt_author_institution'])
|
980 |
-
st.session_state['prompt_name'] = st.session_state['prompt_info'].get('prompt_name', st.session_state['default_prompt_name'])
|
981 |
-
st.session_state['prompt_version'] = st.session_state['prompt_info'].get('prompt_version', st.session_state['default_prompt_version'])
|
982 |
-
st.session_state['prompt_description'] = st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description'])
|
983 |
-
st.session_state['instructions'] = st.session_state['prompt_info'].get('instructions', st.session_state['default_instructions'])
|
984 |
-
st.session_state['json_formatting_instructions'] = st.session_state['prompt_info'].get('json_formatting_instructions', st.session_state['default_json_formatting_instructions'] )
|
985 |
-
st.session_state['rules'] = st.session_state['prompt_info'].get('rules', {})
|
986 |
-
st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
|
987 |
-
st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
|
988 |
-
|
989 |
-
# Placeholder:
|
990 |
-
st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
|
991 |
|
992 |
### Updated to match HF version
|
993 |
# def save_prompt_yaml(filename):
|
994 |
-
def save_prompt_yaml(filename, col):
|
995 |
-
yaml_content = {
|
996 |
-
'prompt_author': st.session_state['prompt_author'],
|
997 |
-
'prompt_author_institution': st.session_state['prompt_author_institution'],
|
998 |
-
'prompt_name': st.session_state['prompt_name'],
|
999 |
-
'prompt_version': st.session_state['prompt_version'],
|
1000 |
-
'prompt_description': st.session_state['prompt_description'],
|
1001 |
-
'LLM': st.session_state['LLM'],
|
1002 |
-
'instructions': st.session_state['instructions'],
|
1003 |
-
'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
1004 |
-
'rules': st.session_state['rules'],
|
1005 |
-
'mapping': st.session_state['mapping'],
|
1006 |
-
}
|
1007 |
-
|
1008 |
-
dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
|
1009 |
-
filepath = os.path.join(dir_prompt, f"{filename}.yaml")
|
1010 |
-
|
1011 |
-
with open(filepath, 'w') as file:
|
1012 |
-
yaml.safe_dump(dict(yaml_content), file, sort_keys=False)
|
1013 |
-
|
1014 |
-
st.success(f"Prompt saved as '{filename}.yaml'.")
|
1015 |
-
|
1016 |
-
with col: # added
|
1017 |
-
create_download_button_yaml(filepath, filename,key_val=2456237465) # added
|
1018 |
-
|
1019 |
-
def check_unique_mapping_assignments():
|
1020 |
-
print(st.session_state['assigned_columns'])
|
1021 |
-
if len(st.session_state['assigned_columns']) != len(set(st.session_state['assigned_columns'])):
|
1022 |
-
st.error("Each column name must be assigned to only one category.")
|
1023 |
-
return False
|
1024 |
-
elif not st.session_state['assigned_columns']:
|
1025 |
-
st.error("No columns have been mapped.")
|
1026 |
-
return False
|
1027 |
-
elif len(st.session_state['assigned_columns']) != len(st.session_state['rules'].keys()):
|
1028 |
-
incomplete = [item for item in list(st.session_state['rules'].keys()) if item not in st.session_state['assigned_columns']]
|
1029 |
-
st.warning(f"These columns have been mapped: {st.session_state['assigned_columns']}")
|
1030 |
-
st.error(f"However, these columns must be mapped before the prompt is complete: {incomplete}")
|
1031 |
-
return False
|
1032 |
-
else:
|
1033 |
-
st.success("Mapping confirmed.")
|
1034 |
-
return True
|
1035 |
-
|
1036 |
-
def check_prompt_yaml_filename(fname):
|
1037 |
-
# Check if the filename only contains letters, numbers, underscores, and dashes
|
1038 |
-
pattern = r'^[\w-]+$'
|
1039 |
-
|
1040 |
-
# The \w matches any alphanumeric character and is equivalent to the character class [a-zA-Z0-9_].
|
1041 |
-
# The hyphen - is literally matched.
|
1042 |
-
|
1043 |
-
if re.match(pattern, fname):
|
1044 |
-
return True
|
1045 |
-
else:
|
1046 |
-
return False
|
1047 |
-
|
1048 |
-
|
1049 |
-
def btn_load_prompt(selected_yaml_file, dir_prompt):
|
1050 |
-
if selected_yaml_file:
|
1051 |
-
yaml_file_path = os.path.join(dir_prompt, selected_yaml_file)
|
1052 |
-
load_prompt_yaml(yaml_file_path)
|
1053 |
-
elif not selected_yaml_file:
|
1054 |
-
# Directly assigning default values since no file is selected
|
1055 |
-
st.session_state['prompt_info'] = {}
|
1056 |
-
st.session_state['prompt_author'] = st.session_state['default_prompt_author']
|
1057 |
-
st.session_state['prompt_author_institution'] = st.session_state['default_prompt_author_institution']
|
1058 |
-
st.session_state['prompt_name'] = st.session_state['prompt_name']
|
1059 |
-
st.session_state['prompt_version'] = st.session_state['prompt_version']
|
1060 |
-
st.session_state['prompt_description'] = st.session_state['default_prompt_description']
|
1061 |
-
st.session_state['instructions'] = st.session_state['default_instructions']
|
1062 |
-
st.session_state['json_formatting_instructions'] = st.session_state['default_json_formatting_instructions']
|
1063 |
-
st.session_state['rules'] = {}
|
1064 |
-
st.session_state['LLM'] = 'General Purpose'
|
1065 |
-
|
1066 |
-
st.session_state['assigned_columns'] = []
|
1067 |
-
|
1068 |
-
st.session_state['prompt_info'] = {
|
1069 |
-
'prompt_author': st.session_state['prompt_author'],
|
1070 |
-
'prompt_author_institution': st.session_state['prompt_author_institution'],
|
1071 |
-
'prompt_name': st.session_state['prompt_name'],
|
1072 |
-
'prompt_version': st.session_state['prompt_version'],
|
1073 |
-
'prompt_description': st.session_state['prompt_description'],
|
1074 |
-
'instructions': st.session_state['instructions'],
|
1075 |
-
'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
1076 |
-
'rules': st.session_state['rules'],
|
1077 |
-
'mapping': st.session_state['mapping'],
|
1078 |
-
'LLM': st.session_state['LLM']
|
1079 |
-
}
|
1080 |
-
|
1081 |
-
def build_LLM_prompt_config():
|
1082 |
-
col_main1, col_main2 = st.columns([10,2])
|
1083 |
-
with col_main1:
|
1084 |
-
st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
|
1085 |
-
st.session_state.logo = Image.open(st.session_state.logo_path)
|
1086 |
-
st.image(st.session_state.logo, width=250)
|
1087 |
-
with col_main2:
|
1088 |
-
if st.button('Exit',key='exist button 2'):
|
1089 |
-
st.session_state.proceed_to_build_llm_prompt = False
|
1090 |
-
st.session_state.proceed_to_main = True
|
1091 |
-
st.rerun()
|
1092 |
-
|
1093 |
-
st.session_state['assigned_columns'] = []
|
1094 |
-
st.session_state['default_prompt_author'] = 'unknown'
|
1095 |
-
st.session_state['default_prompt_author_institution'] = 'unknown'
|
1096 |
-
st.session_state['default_prompt_name'] = 'custom_prompt'
|
1097 |
-
st.session_state['default_prompt_version'] = 'v-1-0'
|
1098 |
-
st.session_state['default_prompt_author_institution'] = 'unknown'
|
1099 |
-
st.session_state['default_prompt_description'] = 'unknown'
|
1100 |
-
st.session_state['default_LLM'] = 'General Purpose'
|
1101 |
-
st.session_state['default_instructions'] = """1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
|
1102 |
-
2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
|
1103 |
-
3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
|
1104 |
-
4. Duplicate dictionary fields are not allowed.
|
1105 |
-
5. Ensure all JSON keys are in camel case.
|
1106 |
-
6. Ensure new JSON field values follow sentence case capitalization.
|
1107 |
-
7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
|
1108 |
-
8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
|
1109 |
-
9. Only return a JSON dictionary represented as a string. You should not explain your answer."""
|
1110 |
-
st.session_state['default_json_formatting_instructions'] = """This section provides rules for formatting each JSON value organized by the JSON key."""
|
1111 |
-
|
1112 |
-
# Start building the Streamlit app
|
1113 |
-
col_prompt_main_left, ___, col_prompt_main_right = st.columns([6,1,3])
|
1114 |
-
|
1115 |
-
|
1116 |
-
with col_prompt_main_left:
|
1117 |
-
|
1118 |
-
st.title("Custom LLM Prompt Builder")
|
1119 |
-
st.subheader('About')
|
1120 |
-
st.write("This form allows you to craft a prompt for your specific task. You can also edit the JSON yaml files directly, but please try loading the prompt back into this form to ensure that the formatting is correct. If this form cannot load your manually edited JSON yaml file, then it will not work in VoucherVision.")
|
1121 |
-
st.subheader(':rainbow[How it Works]')
|
1122 |
-
st.write("1. Edit this page until you are happy with your instructions. We recommend looking at the basic structure, writing down your prompt inforamtion in a Word document so that it does not randomly disappear, and then copying and pasting that info into this form once your whole prompt structure is defined.")
|
1123 |
-
st.write("2. After you enter all of your prompt instructions, click 'Save' and give your file a name.")
|
1124 |
-
st.write("3. This file will be saved as a yaml configuration file in the `..VoucherVision/custom_prompts` folder.")
|
1125 |
-
st.write("4. When you go back the main VoucherVision page you will now see your custom prompt available in the 'Prompt Version' dropdown menu.")
|
1126 |
-
st.write("5. The LLM ***only*** sees information from the 'instructions', 'rules', and 'json_formatting_instructions' sections. All other information is for versioning and integration with VoucherVisionEditor.")
|
1127 |
-
|
1128 |
-
st.write("---")
|
1129 |
-
st.header('Load an Existing Prompt Template')
|
1130 |
-
st.write("By default, this form loads the minimum required transcription fields but does not provide rules for each field. You can also load an existing prompt as a template, editing or deleting values as needed.")
|
1131 |
-
|
1132 |
-
dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
|
1133 |
-
yaml_files = [f for f in os.listdir(dir_prompt) if f.endswith('.yaml')]
|
1134 |
-
col_load_text, col_load_btn, col_load_btn2 = st.columns([8,2,2])
|
1135 |
-
with col_load_text:
|
1136 |
-
# Dropdown for selecting a YAML file
|
1137 |
-
st.session_state['selected_yaml_file'] = st.selectbox('Select a prompt .YAML file to load:', [''] + yaml_files)
|
1138 |
-
with col_load_btn:
|
1139 |
-
st.write('##')
|
1140 |
-
# Button to load the selected prompt
|
1141 |
-
st.button('Load Prompt', on_click=btn_load_prompt, args=[st.session_state['selected_yaml_file'], dir_prompt],use_container_width=True)
|
1142 |
-
|
1143 |
-
with col_load_btn2:
|
1144 |
-
if st.session_state['selected_yaml_file']:
|
1145 |
-
# Construct the full path to the file
|
1146 |
-
download_file_path = os.path.join(dir_prompt, st.session_state['selected_yaml_file'] )
|
1147 |
-
# Create the download button
|
1148 |
-
st.write('##')
|
1149 |
-
create_download_button_yaml(download_file_path, st.session_state['selected_yaml_file'],key_val=345798)
|
1150 |
-
|
1151 |
-
# Prompt Author Information
|
1152 |
-
st.write("---")
|
1153 |
-
st.header("Prompt Author Information")
|
1154 |
-
st.write("We value community contributions! Please provide your name(s) (or pseudonym if you prefer) for credit. If you leave this field blank, it will say 'unknown'.")
|
1155 |
-
if 'prompt_author' not in st.session_state:# != st.session_state['default_prompt_author']:
|
1156 |
-
st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['default_prompt_author'],key=1111)
|
1157 |
-
else:
|
1158 |
-
st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['prompt_author'],key=1112)
|
1159 |
-
|
1160 |
-
# Institution
|
1161 |
-
st.write("Please provide your institution name. If you leave this field blank, it will say 'unknown'.")
|
1162 |
-
if 'prompt_author_institution' not in st.session_state:
|
1163 |
-
st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['default_prompt_author_institution'],key=1113)
|
1164 |
-
else:
|
1165 |
-
st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['prompt_author_institution'],key=1114)
|
1166 |
-
|
1167 |
-
# Prompt name
|
1168 |
-
st.write("Please provide a simple name for your prompt. If you leave this field blank, it will say 'custom_prompt'.")
|
1169 |
-
if 'prompt_name' not in st.session_state:
|
1170 |
-
st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['default_prompt_name'],key=1115)
|
1171 |
-
else:
|
1172 |
-
st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['prompt_name'],key=1116)
|
1173 |
-
|
1174 |
-
# Prompt verion
|
1175 |
-
st.write("Please provide a version identifier for your prompt. If you leave this field blank, it will say 'v-1-0'.")
|
1176 |
-
if 'prompt_version' not in st.session_state:
|
1177 |
-
st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['default_prompt_version'],key=1117)
|
1178 |
-
else:
|
1179 |
-
st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['prompt_version'],key=1118)
|
1180 |
-
|
1181 |
-
|
1182 |
-
st.write("Please provide a description of your prompt and its intended task. Is it designed for a specific collection? Taxa? Database structure?")
|
1183 |
-
if 'prompt_description' not in st.session_state:
|
1184 |
-
st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['default_prompt_description'],key=1119)
|
1185 |
-
else:
|
1186 |
-
st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['prompt_description'],key=11111)
|
1187 |
-
|
1188 |
-
st.write('---')
|
1189 |
-
st.header("Set LLM Model Type")
|
1190 |
-
# Define the options for the dropdown
|
1191 |
-
llm_options_general = ["General Purpose",
|
1192 |
-
"OpenAI GPT Models","Google PaLM2 Models","Google Gemini Models","MistralAI Models",]
|
1193 |
-
llm_options_all = ModelMaps.get_models_gui_list()
|
1194 |
-
|
1195 |
-
if 'LLM' not in st.session_state:
|
1196 |
-
st.session_state['LLM'] = st.session_state['default_LLM']
|
1197 |
-
|
1198 |
-
if st.session_state['LLM']:
|
1199 |
-
llm_options = llm_options_general + llm_options_all + [st.session_state['LLM']]
|
1200 |
-
else:
|
1201 |
-
llm_options = llm_options_general + llm_options_all
|
1202 |
-
# Create the dropdown and set the value to session_state['LLM']
|
1203 |
-
st.write("Which LLM is this prompt designed for? This will not restrict its use to a specific LLM, but some prompts will behave differently across models.")
|
1204 |
-
st.write("SLTPvA prompts have been validated with all supported LLMs, but perfornce may vary. If you design a prompt to work best with a specific model, then you can indicate the model here.")
|
1205 |
-
st.write("For general purpose prompts (like the SLTPvA prompts) just use the 'General Purpose' option.")
|
1206 |
-
st.session_state['LLM'] = st.selectbox('Set LLM', llm_options, index=llm_options.index(st.session_state.get('LLM', 'General Purpose')))
|
1207 |
-
|
1208 |
-
st.write('---')
|
1209 |
-
# Instructions Section
|
1210 |
-
st.header("Instructions")
|
1211 |
-
st.write("These are the general instructions that guide the LLM through the transcription task. We recommend using the default instructions unless you have a specific reason to change them.")
|
1212 |
-
|
1213 |
-
if 'instructions' not in st.session_state:
|
1214 |
-
st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['default_instructions'].strip(), height=350,key=111112)
|
1215 |
-
else:
|
1216 |
-
st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['instructions'].strip(), height=350,key=111112)
|
1217 |
-
|
1218 |
-
|
1219 |
-
st.write('---')
|
1220 |
-
|
1221 |
-
# Column Instructions Section
|
1222 |
-
st.header("JSON Formatting Instructions")
|
1223 |
-
st.write("The following section tells the LLM how we want to structure the JSON dictionary. We do not recommend changing this section because it would likely result in unstable and inconsistent behavior.")
|
1224 |
-
if 'json_formatting_instructions' not in st.session_state:
|
1225 |
-
st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['default_json_formatting_instructions'],key=111114)
|
1226 |
-
else:
|
1227 |
-
st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['json_formatting_instructions'],key=111115)
|
1228 |
-
|
1229 |
-
|
1230 |
-
|
1231 |
-
|
1232 |
-
|
1233 |
-
|
1234 |
-
st.write('---')
|
1235 |
-
col_left, col_right = st.columns([6,4])
|
1236 |
-
|
1237 |
-
null_value_rules = ''
|
1238 |
-
c_name = "EXAMPLE_COLUMN_NAME"
|
1239 |
-
c_value = "REPLACE WITH DESCRIPTION"
|
1240 |
-
|
1241 |
-
with col_left:
|
1242 |
-
st.subheader('Add/Edit Columns')
|
1243 |
-
st.markdown("The pre-populated fields are REQUIRED for downstream validation steps. They must be in all prompts.")
|
1244 |
-
|
1245 |
-
# Initialize rules in session state if not already present
|
1246 |
-
if 'rules' not in st.session_state or not st.session_state['rules']:
|
1247 |
-
for required_col in st.session_state['required_fields']:
|
1248 |
-
st.session_state['rules'][required_col] = c_value
|
1249 |
-
|
1250 |
-
|
1251 |
-
|
1252 |
-
|
1253 |
-
# Layout for adding a new column name
|
1254 |
-
# col_text, col_textbtn = st.columns([8, 2])
|
1255 |
-
# with col_text:
|
1256 |
-
st.session_state['new_column_name'] = st.text_input("Enter a new column name:")
|
1257 |
-
# with col_textbtn:
|
1258 |
-
# st.write('##')
|
1259 |
-
if st.button("Add New Column") and st.session_state['new_column_name']:
|
1260 |
-
if st.session_state['new_column_name'] not in st.session_state['rules']:
|
1261 |
-
st.session_state['rules'][st.session_state['new_column_name']] = c_value
|
1262 |
-
st.success(f"New column '{st.session_state['new_column_name']}' added. Now you can edit its properties.")
|
1263 |
-
st.session_state['new_column_name'] = ''
|
1264 |
-
else:
|
1265 |
-
st.error("Column name already exists. Please enter a unique column name.")
|
1266 |
-
st.session_state['new_column_name'] = ''
|
1267 |
-
|
1268 |
-
|
1269 |
-
# Get columns excluding the protected "catalogNumber"
|
1270 |
-
st.write('#')
|
1271 |
-
# required_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
|
1272 |
-
editable_columns = [col for col in st.session_state['rules'] if col not in ["catalogNumber"]]
|
1273 |
-
removable_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
|
1274 |
-
|
1275 |
-
st.session_state['current_rule'] = st.selectbox("Select a column to edit:", [""] + editable_columns)
|
1276 |
-
# column_name = st.selectbox("Select a column to edit:", editable_columns)
|
1277 |
-
|
1278 |
-
|
1279 |
-
|
1280 |
-
# if 'current_rule' not in st.session_state:
|
1281 |
-
# st.session_state['current_rule'] = current_rule
|
1282 |
-
|
1283 |
-
|
1284 |
-
|
1285 |
-
|
1286 |
-
|
1287 |
-
# Form for input fields
|
1288 |
-
with st.form(key='rule_form'):
|
1289 |
-
# format_options = ["verbatim transcription", "spell check transcription", "boolean yes no", "boolean 1 0", "integer", "[list]", "yyyy-mm-dd"]
|
1290 |
-
# current_rule["format"] = st.selectbox("Format:", format_options, index=format_options.index(current_rule["format"]) if current_rule["format"] else 0)
|
1291 |
-
# current_rule["null_value"] = st.text_input("Null value:", value=current_rule["null_value"])
|
1292 |
-
if st.session_state['current_rule']:
|
1293 |
-
current_rule_description = st.text_area("Description of category:", value=st.session_state['rules'][st.session_state['current_rule']])
|
1294 |
-
else:
|
1295 |
-
current_rule_description = ''
|
1296 |
-
commit_button = st.form_submit_button("Commit Column")
|
1297 |
-
|
1298 |
-
# default_rule = {
|
1299 |
-
# "format": format_options[0], # default format
|
1300 |
-
# "null_value": "", # default null value
|
1301 |
-
# "description": "", # default description
|
1302 |
-
# }
|
1303 |
-
# if st.session_state['current_rule'] != st.session_state['current_rule']:
|
1304 |
-
# # Column has changed. Update the session_state selected column.
|
1305 |
-
# st.session_state['current_rule'] = st.session_state['current_rule']
|
1306 |
-
# # Reset the current rule to the default for this new column, or a blank rule if not set.
|
1307 |
-
# current_rule = st.session_state['rules'][st.session_state['current_rule']].get(current_rule, c_value)
|
1308 |
-
|
1309 |
-
# Handle commit action
|
1310 |
-
if commit_button and st.session_state['current_rule']:
|
1311 |
-
# Commit the rules to the session state.
|
1312 |
-
st.session_state['rules'][st.session_state['current_rule']] = current_rule_description
|
1313 |
-
st.success(f"Column '{st.session_state['current_rule']}' added/updated in rules.")
|
1314 |
-
|
1315 |
-
# Force the form to reset by clearing the fields from the session state
|
1316 |
-
st.session_state.pop('current_rule', None) # Clear the selected column to force reset
|
1317 |
-
|
1318 |
-
# st.session_state['rules'][column_name] = current_rule
|
1319 |
-
# st.success(f"Column '{column_name}' added/updated in rules.")
|
1320 |
-
|
1321 |
-
# # Reset current_rule to default values for the next input
|
1322 |
-
# current_rule["format"] = default_rule["format"]
|
1323 |
-
# current_rule["null_value"] = default_rule["null_value"]
|
1324 |
-
# current_rule["description"] = default_rule["description"]
|
1325 |
-
|
1326 |
-
# # To ensure that the form fields are reset, we can clear them from the session state
|
1327 |
-
# for key in current_rule.keys():
|
1328 |
-
# st.session_state[key] = default_rule[key]
|
1329 |
-
|
1330 |
-
# Layout for removing an existing column
|
1331 |
-
# del_col, del_colbtn = st.columns([8, 2])
|
1332 |
-
# with del_col:
|
1333 |
-
delete_column_name = st.selectbox("Select a column to delete:", [""] + removable_columns)
|
1334 |
-
# with del_colbtn:
|
1335 |
-
# st.write('##')
|
1336 |
-
if st.button("Delete Column") and delete_column_name:
|
1337 |
-
del st.session_state['rules'][delete_column_name]
|
1338 |
-
st.success(f"Column '{delete_column_name}' removed from rules.")
|
1339 |
-
|
1340 |
|
1341 |
-
|
1342 |
-
|
1343 |
-
with col_right:
|
1344 |
-
# Display the current state of the JSON rules
|
1345 |
-
st.subheader('Formatted Columns')
|
1346 |
-
st.json(st.session_state['rules'])
|
1347 |
-
|
1348 |
-
# st.subheader('All Prompt Info')
|
1349 |
-
# st.json(st.session_state['prompt_info'])
|
1350 |
-
|
1351 |
-
|
1352 |
-
st.write('---')
|
1353 |
-
|
1354 |
-
|
1355 |
-
col_left_mapping, col_right_mapping = st.columns([6,4])
|
1356 |
-
with col_left_mapping:
|
1357 |
-
st.header("Mapping")
|
1358 |
-
st.write("Assign each column name to a single category.")
|
1359 |
-
st.session_state['refresh_mapping'] = False
|
1360 |
-
|
1361 |
-
# Dynamically create a list of all column names that can be assigned
|
1362 |
-
# This assumes that the column names are the keys in the dictionary under 'rules'
|
1363 |
-
all_column_names = list(st.session_state['rules'].keys())
|
1364 |
-
|
1365 |
-
categories = ['TAXONOMY', 'GEOGRAPHY', 'LOCALITY', 'COLLECTING', 'MISC']
|
1366 |
-
if ('mapping' not in st.session_state) or (st.session_state['mapping'] == {}):
|
1367 |
-
st.session_state['mapping'] = {category: [] for category in categories}
|
1368 |
-
for category in categories:
|
1369 |
-
# Filter out the already assigned columns
|
1370 |
-
available_columns = [col for col in all_column_names if col not in st.session_state['assigned_columns'] or col in st.session_state['mapping'].get(category, [])]
|
1371 |
-
|
1372 |
-
# Ensure the current mapping is a subset of the available options
|
1373 |
-
current_mapping = [col for col in st.session_state['mapping'].get(category, []) if col in available_columns]
|
1374 |
-
|
1375 |
-
# Provide a safe default if the current mapping is empty or contains invalid options
|
1376 |
-
safe_default = current_mapping if all(col in available_columns for col in current_mapping) else []
|
1377 |
-
|
1378 |
-
# Create a multi-select widget for the category with a safe default
|
1379 |
-
selected_columns = st.multiselect(
|
1380 |
-
f"Select columns for {category}:",
|
1381 |
-
available_columns,
|
1382 |
-
default=safe_default,
|
1383 |
-
key=f"mapping_{category}"
|
1384 |
-
)
|
1385 |
-
# Update the assigned_columns based on the selections
|
1386 |
-
for col in current_mapping:
|
1387 |
-
if col not in selected_columns and col in st.session_state['assigned_columns']:
|
1388 |
-
st.session_state['assigned_columns'].remove(col)
|
1389 |
-
st.session_state['refresh_mapping'] = True
|
1390 |
-
|
1391 |
-
for col in selected_columns:
|
1392 |
-
if col not in st.session_state['assigned_columns']:
|
1393 |
-
st.session_state['assigned_columns'].append(col)
|
1394 |
-
st.session_state['refresh_mapping'] = True
|
1395 |
-
|
1396 |
-
# Update the mapping in session state when there's a change
|
1397 |
-
st.session_state['mapping'][category] = selected_columns
|
1398 |
-
if st.session_state['refresh_mapping']:
|
1399 |
-
st.session_state['refresh_mapping'] = False
|
1400 |
-
|
1401 |
-
# Button to confirm and save the mapping configuration
|
1402 |
-
if st.button('Confirm Mapping'):
|
1403 |
-
if check_unique_mapping_assignments():
|
1404 |
-
# Proceed with further actions since the mapping is confirmed and unique
|
1405 |
-
pass
|
1406 |
-
|
1407 |
-
with col_right_mapping:
|
1408 |
-
# Display the current state of the JSON rules
|
1409 |
-
st.subheader('Formatted Column Maps')
|
1410 |
-
st.json(st.session_state['mapping'])
|
1411 |
-
|
1412 |
-
|
1413 |
-
col_left_save, col_right_save = st.columns([6,4])
|
1414 |
-
with col_left_save:
|
1415 |
-
# Input for new file name
|
1416 |
-
new_filename = st.text_input("Enter filename to save your prompt as a configuration YAML:",placeholder='my_prompt_name')
|
1417 |
-
# Button to save the new YAML file
|
1418 |
-
if st.button('Save YAML', type='primary'):
|
1419 |
-
if new_filename:
|
1420 |
-
if check_unique_mapping_assignments():
|
1421 |
-
if check_prompt_yaml_filename(new_filename):
|
1422 |
-
save_prompt_yaml(new_filename, col_left_save)
|
1423 |
-
else:
|
1424 |
-
st.error("File name can only contain letters, numbers, underscores, and dashes. Cannot contain spaces.")
|
1425 |
-
else:
|
1426 |
-
st.error("Mapping contains an error. Make sure that each column is assigned to only ***one*** category.")
|
1427 |
-
else:
|
1428 |
-
st.error("Please enter a filename.")
|
1429 |
-
|
1430 |
-
if st.button('Exit'):
|
1431 |
-
st.session_state.proceed_to_build_llm_prompt = False
|
1432 |
-
st.session_state.proceed_to_main = True
|
1433 |
-
st.rerun()
|
1434 |
|
1435 |
|
1436 |
|
1437 |
-
# st.write('---')
|
1438 |
-
# st.header("Save and Download Custom Prompt")
|
1439 |
-
# st.write('Once you click save, validation checks will verify the formatting and then a download button will appear so that you can ***save a local copy of your custom prompt.***')
|
1440 |
-
# col_left_save, col_right_save, _ = st.columns([2,2,8])
|
1441 |
-
# with col_left_save:
|
1442 |
-
# # Button to save the new YAML file
|
1443 |
-
# if st.button('Save YAML', type='primary',key=3450798):
|
1444 |
-
# if st.session_state['prompt_name']:
|
1445 |
-
# if check_unique_mapping_assignments():
|
1446 |
-
# if check_prompt_yaml_filename(st.session_state['prompt_name']):
|
1447 |
-
# save_prompt_yaml(st.session_state['prompt_name'], col_right_save)
|
1448 |
-
# else:
|
1449 |
-
# st.error("File name can only contain letters, numbers, underscores, and dashes. Cannot contain spaces.")
|
1450 |
-
# else:
|
1451 |
-
# st.error("Mapping contains an error. Make sure that each column is assigned to only ***one*** category.")
|
1452 |
-
# else:
|
1453 |
-
# st.error("Please enter a filename.")
|
1454 |
-
|
1455 |
-
# with col_prompt_main_right:
|
1456 |
-
# st.subheader('All Prompt Components')
|
1457 |
-
# st.session_state['prompt_info'] = {
|
1458 |
-
# 'prompt_author': st.session_state['prompt_author'],
|
1459 |
-
# 'prompt_author_institution': st.session_state['prompt_author_institution'],
|
1460 |
-
# 'prompt_name': st.session_state['prompt_name'],
|
1461 |
-
# 'prompt_version': st.session_state['prompt_version'],
|
1462 |
-
# 'prompt_description': st.session_state['prompt_description'],
|
1463 |
-
# 'LLM': st.session_state['LLM'],
|
1464 |
-
# 'instructions': st.session_state['instructions'],
|
1465 |
-
# 'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
1466 |
-
# 'rules': st.session_state['rules'],
|
1467 |
-
# 'mapping': st.session_state['mapping'],
|
1468 |
-
# }
|
1469 |
-
# st.json(st.session_state['prompt_info'])
|
1470 |
-
with col_prompt_main_right:
|
1471 |
-
if st.session_state['user_clicked_load_prompt_yaml'] is None: # see if user has loaded a yaml to edit
|
1472 |
-
st.session_state['show_prompt_name_e'] = f"Prompt Status :arrow_forward: Building prompt from scratch"
|
1473 |
-
if st.session_state['prompt_name']:
|
1474 |
-
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
|
1475 |
-
else:
|
1476 |
-
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
|
1477 |
-
else:
|
1478 |
-
st.session_state['show_prompt_name_e'] = f"Prompt Status: Editing :arrow_forward: {st.session_state['selected_yaml_file']}"
|
1479 |
-
if st.session_state['prompt_name']:
|
1480 |
-
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
|
1481 |
-
else:
|
1482 |
-
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
|
1483 |
-
|
1484 |
-
st.subheader(f'Full Prompt')
|
1485 |
-
st.write(st.session_state['show_prompt_name_e'])
|
1486 |
-
st.write(st.session_state['show_prompt_name_w'])
|
1487 |
-
st.write("---")
|
1488 |
-
st.session_state['prompt_info'] = {
|
1489 |
-
'prompt_author': st.session_state['prompt_author'],
|
1490 |
-
'prompt_author_institution': st.session_state['prompt_author_institution'],
|
1491 |
-
'prompt_name': st.session_state['prompt_name'],
|
1492 |
-
'prompt_version': st.session_state['prompt_version'],
|
1493 |
-
'prompt_description': st.session_state['prompt_description'],
|
1494 |
-
'LLM': st.session_state['LLM'],
|
1495 |
-
'instructions': st.session_state['instructions'],
|
1496 |
-
'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
1497 |
-
'rules': st.session_state['rules'],
|
1498 |
-
'mapping': st.session_state['mapping'],
|
1499 |
-
}
|
1500 |
-
st.json(st.session_state['prompt_info'])
|
1501 |
-
|
1502 |
def show_header_welcome():
|
1503 |
st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
|
1504 |
st.session_state.logo = Image.open(st.session_state.logo_path)
|
@@ -1676,7 +1297,7 @@ def content_header():
|
|
1676 |
with col_run_4:
|
1677 |
with st.expander("View Messages and Updates"):
|
1678 |
st.info("***Note:*** If you use VoucherVision frequently, you can change the default values that are auto-populated in the form below. In a text editor or IDE, edit the first few rows in the file `../VoucherVision/vouchervision/VoucherVision_Config_Builder.py`")
|
1679 |
-
|
1680 |
|
1681 |
col_test = st.container()
|
1682 |
|
@@ -1686,13 +1307,6 @@ def content_header():
|
|
1686 |
col_json, col_json_WFO, col_json_GEO, col_json_map = st.columns([2, 2, 2, 2])
|
1687 |
|
1688 |
with col_run_info_1:
|
1689 |
-
# Progress
|
1690 |
-
# Progress
|
1691 |
-
# st.subheader('Project')
|
1692 |
-
# bar = st.progress(0)
|
1693 |
-
# new_text = st.empty() # Placeholder for current step name
|
1694 |
-
# progress_report = ProgressReportVV(bar, new_text, n_images=10)
|
1695 |
-
|
1696 |
# Progress
|
1697 |
overall_progress_bar = st.progress(0)
|
1698 |
text_overall = st.empty() # Placeholder for current step name
|
@@ -1700,23 +1314,14 @@ def content_header():
|
|
1700 |
batch_progress_bar = st.progress(0)
|
1701 |
text_batch = st.empty() # Placeholder for current step name
|
1702 |
progress_report = ProgressReport(overall_progress_bar, batch_progress_bar, text_overall, text_batch)
|
1703 |
-
# st.session_state['json_report'] = JSONReport(col_updates_1, col_json, col_json_WFO, col_json_GEO, col_json_map)
|
1704 |
st.session_state['hold_output'] = st.toggle('View Final Transcription')
|
1705 |
|
1706 |
with col_logo:
|
1707 |
show_header_welcome()
|
1708 |
|
1709 |
with col_run_1:
|
1710 |
-
# st.subheader('Run VoucherVision')
|
1711 |
N_STEPS = 6
|
1712 |
|
1713 |
-
# if st.session_state.is_hf:
|
1714 |
-
# count_n_imgs = determine_n_images()
|
1715 |
-
# if count_n_imgs > 0:
|
1716 |
-
# st.session_state['processing_add_on'] = count_n_imgs
|
1717 |
-
# else:
|
1718 |
-
# st.session_state['processing_add_on'] = 0
|
1719 |
-
|
1720 |
if check_if_usable(is_hf=st.session_state['is_hf']):
|
1721 |
b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
|
1722 |
if st.session_state['processing_add_on'] == 0:
|
@@ -1740,21 +1345,28 @@ def content_header():
|
|
1740 |
total_cost = 0.00
|
1741 |
n_failed_OCR = 0
|
1742 |
n_failed_LLM_calls = 0
|
1743 |
-
try:
|
1744 |
-
|
1745 |
-
|
1746 |
-
|
1747 |
-
|
1748 |
-
|
1749 |
-
|
1750 |
-
|
1751 |
-
|
1752 |
-
|
1753 |
-
|
1754 |
-
|
1755 |
-
|
1756 |
-
|
1757 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1758 |
|
1759 |
if n_failed_OCR > 0:
|
1760 |
with col_run_4:
|
@@ -1791,8 +1403,13 @@ def content_header():
|
|
1791 |
with ct_left:
|
1792 |
st.button("Refresh", on_click=refresh, use_container_width=True)
|
1793 |
with ct_right:
|
1794 |
-
|
1795 |
-
|
|
|
|
|
|
|
|
|
|
|
1796 |
|
1797 |
# with col_run_2:
|
1798 |
# if st.button("Test GPT"):
|
@@ -1869,14 +1486,6 @@ def content_header():
|
|
1869 |
|
1870 |
|
1871 |
|
1872 |
-
|
1873 |
-
|
1874 |
-
|
1875 |
-
|
1876 |
-
|
1877 |
-
|
1878 |
-
|
1879 |
-
|
1880 |
def content_project_settings(col):
|
1881 |
### Project
|
1882 |
with col:
|
@@ -1966,9 +1575,10 @@ def content_prompt_and_llm_version():
|
|
1966 |
st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version),label_visibility='collapsed')
|
1967 |
|
1968 |
with col_prompt_2:
|
1969 |
-
if st.button("Build Custom LLM Prompt"):
|
1970 |
-
|
1971 |
-
|
|
|
1972 |
|
1973 |
st.header('LLM Version')
|
1974 |
col_llm_1, col_llm_2 = st.columns([4,2])
|
@@ -2004,13 +1614,66 @@ def content_api_check():
|
|
2004 |
st.rerun()
|
2005 |
|
2006 |
|
2007 |
-
|
2008 |
-
|
2009 |
|
2010 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2011 |
st.write("---")
|
2012 |
-
|
2013 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2014 |
demo_text_h = f"Google_OCR_Handwriting:\nHERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 TX 11 Ilowers pink UNIVERSITE HERBARIUM MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
|
2015 |
demo_text_tr = f"trOCR:\nherbarium of marcus w. lyon jr. : : : tracaulon sagittatum indiana porter co. incal springs TX 11 Ilowers pink 1439649 copyright reserved D H U Q "
|
2016 |
demo_text_p = f"Google_OCR_Printed:\nTracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 Ilowers pink 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
|
@@ -2019,11 +1682,125 @@ def content_collage_overlay():
|
|
2019 |
demo_text_trh = demo_text_h + '\n' + demo_text_tr
|
2020 |
demo_text_trp = demo_text_p + '\n' + demo_text_tr
|
2021 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2022 |
with col_collage:
|
2023 |
st.header('LeafMachine2 Label Collage')
|
|
|
2024 |
default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
|
2025 |
st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
|
2026 |
-
st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox("Use LeafMachine2 label collage for transcriptions", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))
|
2027 |
|
2028 |
|
2029 |
option_selected_crops = st.multiselect(label="Components to crop",
|
@@ -2040,76 +1817,14 @@ def content_collage_overlay():
|
|
2040 |
with st.expander(":frame_with_picture: View an example of the LeafMachine2 collage image"):
|
2041 |
st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="PNG")
|
2042 |
# st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="JPEG")
|
2043 |
-
|
2044 |
-
|
2045 |
|
2046 |
with col_overlay:
|
2047 |
st.header('OCR Overlay Image')
|
2048 |
-
options = [":rainbow[Printed + Handwritten]", "Printed", "Use both models"]
|
2049 |
-
captions = [
|
2050 |
-
"Works well for both printed and handwritten text",
|
2051 |
-
"Works for printed text",
|
2052 |
-
"Adds both OCR versions to the LLM prompt"
|
2053 |
-
]
|
2054 |
|
2055 |
st.write('This will plot bounding boxes around all text that Google Vision was able to detect. If there are no boxes around text, then the OCR failed, so that missing text will not be seen by the LLM when it is creating the JSON object. The created image will be viewable in the VoucherVisionEditor.')
|
2056 |
|
2057 |
do_create_OCR_helper_image = st.checkbox("Create image showing an overlay of the OCR detections",value=st.session_state.config['leafmachine']['do_create_OCR_helper_image'],disabled=True)
|
2058 |
st.session_state.config['leafmachine']['do_create_OCR_helper_image'] = do_create_OCR_helper_image
|
2059 |
-
|
2060 |
-
|
2061 |
-
|
2062 |
-
|
2063 |
-
# Get the current OCR option from session state
|
2064 |
-
OCR_option = st.session_state.config['leafmachine']['project']['OCR_option']
|
2065 |
-
|
2066 |
-
# Map the OCR option to the index in options list
|
2067 |
-
# You need to define the mapping based on your application's logic
|
2068 |
-
option_to_index = {
|
2069 |
-
'hand': 0,
|
2070 |
-
'normal': 1,
|
2071 |
-
'both': 2,
|
2072 |
-
}
|
2073 |
-
default_index = option_to_index.get(OCR_option, 0) # Default to 0 if option not found
|
2074 |
-
|
2075 |
-
# Create the radio button
|
2076 |
-
OCR_option_select = st.radio(
|
2077 |
-
"Select the Google Vision OCR version.",
|
2078 |
-
options,
|
2079 |
-
index=default_index,
|
2080 |
-
help="",captions=captions,
|
2081 |
-
)
|
2082 |
-
st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option_select
|
2083 |
-
|
2084 |
-
if OCR_option_select == ":rainbow[Printed + Handwritten]":
|
2085 |
-
OCR_option = 'hand'
|
2086 |
-
elif OCR_option_select == "Printed":
|
2087 |
-
OCR_option = 'normal'
|
2088 |
-
elif OCR_option_select == "Use both models":
|
2089 |
-
OCR_option = 'both'
|
2090 |
-
else:
|
2091 |
-
raise
|
2092 |
-
|
2093 |
-
st.write("Supplement Google Vision OCR with trOCR (handwriting OCR) using `microsoft/trocr-base-handwritten`. This option requires Google Vision API and a GPU.")
|
2094 |
-
do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'])#,disabled=st.session_state['lacks_GPU'])
|
2095 |
-
st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
|
2096 |
-
|
2097 |
-
|
2098 |
-
st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option
|
2099 |
-
st.markdown("Below is an example of what the LLM would see given the choice of OCR ensemble. One, two, or three version of OCR can be fed into the LLM prompt. Typically, 'printed + handwritten' works well. If you have a GPU then you can enable trOCR.")
|
2100 |
-
if (OCR_option == 'hand') and not do_use_trOCR:
|
2101 |
-
st.text_area(label='Handwritten/Printed',placeholder=demo_text_h,disabled=True, label_visibility='visible', height=150)
|
2102 |
-
elif (OCR_option == 'normal') and not do_use_trOCR:
|
2103 |
-
st.text_area(label='Printed',placeholder=demo_text_p,disabled=True, label_visibility='visible', height=150)
|
2104 |
-
elif (OCR_option == 'both') and not do_use_trOCR:
|
2105 |
-
st.text_area(label='Handwritten/Printed + Printed',placeholder=demo_text_b,disabled=True, label_visibility='visible', height=150)
|
2106 |
-
elif (OCR_option == 'both') and do_use_trOCR:
|
2107 |
-
st.text_area(label='Handwritten/Printed + Printed + trOCR',placeholder=demo_text_trb,disabled=True, label_visibility='visible', height=150)
|
2108 |
-
elif (OCR_option == 'normal') and do_use_trOCR:
|
2109 |
-
st.text_area(label='Printed + trOCR',placeholder=demo_text_trp,disabled=True, label_visibility='visible', height=150)
|
2110 |
-
elif (OCR_option == 'hand') and do_use_trOCR:
|
2111 |
-
st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
|
2112 |
-
|
2113 |
|
2114 |
if "demo_overlay" not in st.session_state:
|
2115 |
# ocr = os.path.join(st.session_state.dir_home,'demo', 'ba','ocr.png')
|
@@ -2159,6 +1874,8 @@ def content_processing_options():
|
|
2159 |
st.subheader('Compute Options')
|
2160 |
st.session_state.config['leafmachine']['project']['num_workers'] = st.number_input("Number of CPU workers", value=st.session_state.config['leafmachine']['project'].get('num_workers', 1), disabled=False)
|
2161 |
st.session_state.config['leafmachine']['project']['batch_size'] = st.number_input("Batch size", value=st.session_state.config['leafmachine']['project'].get('batch_size', 500), help='Sets the batch size for the LeafMachine2 cropping. If computer RAM is filled, lower this value to ~100.')
|
|
|
|
|
2162 |
with col_processing_2:
|
2163 |
st.subheader('Filename Prefix Handling')
|
2164 |
st.session_state.config['leafmachine']['project']['prefix_removal'] = st.text_input("Remove prefix from catalog number", st.session_state.config['leafmachine']['project'].get('prefix_removal', ''),placeholder="e.g. MICH-V-")
|
@@ -2167,18 +1884,21 @@ def content_processing_options():
|
|
2167 |
|
2168 |
### Logging and Image Validation - col_v1
|
2169 |
st.write("---")
|
2170 |
-
st.header('Logging and Image Validation')
|
2171 |
col_v1, col_v2 = st.columns(2)
|
|
|
2172 |
with col_v1:
|
|
|
2173 |
option_check_illegal = st.checkbox("Check for illegal filenames", value=st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'])
|
2174 |
st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'] = option_check_illegal
|
2175 |
-
|
|
|
|
|
|
|
2176 |
st.session_state.config['leafmachine']['do']['check_for_corrupt_images_make_vertical'] = st.checkbox("Check for corrupt images", st.session_state.config['leafmachine']['do'].get('check_for_corrupt_images_make_vertical', True),disabled=True)
|
2177 |
|
2178 |
st.session_state.config['leafmachine']['print']['verbose'] = st.checkbox("Print verbose", st.session_state.config['leafmachine']['print'].get('verbose', True))
|
2179 |
st.session_state.config['leafmachine']['print']['optional_warnings'] = st.checkbox("Show optional warnings", st.session_state.config['leafmachine']['print'].get('optional_warnings', True))
|
2180 |
-
|
2181 |
-
with col_v2:
|
2182 |
log_level = st.session_state.config['leafmachine']['logging'].get('log_level', None)
|
2183 |
log_level_display = log_level if log_level is not None else 'default'
|
2184 |
selected_log_level = st.selectbox("Logging Level", ['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'], index=['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'].index(log_level_display))
|
@@ -2188,6 +1908,28 @@ def content_processing_options():
|
|
2188 |
else:
|
2189 |
st.session_state.config['leafmachine']['logging']['log_level'] = selected_log_level
|
2190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2191 |
|
2192 |
|
2193 |
def content_tab_domain():
|
@@ -2254,7 +1996,9 @@ def render_expense_report_summary():
|
|
2254 |
expense_report = st.session_state.expense_report
|
2255 |
st.header('Expense Report Summary')
|
2256 |
|
2257 |
-
if expense_summary:
|
|
|
|
|
2258 |
st.metric(label="Total Cost", value=f"${round(expense_summary['total_cost_sum'], 4):,}")
|
2259 |
col1, col2 = st.columns(2)
|
2260 |
|
@@ -2348,19 +2092,21 @@ def render_expense_report_summary():
|
|
2348 |
pie_chart.update_traces(marker=dict(colors=colors),)
|
2349 |
st.plotly_chart(pie_chart, use_container_width=True)
|
2350 |
|
2351 |
-
else:
|
2352 |
-
st.error('No expense report data available.')
|
2353 |
-
|
2354 |
-
|
2355 |
|
2356 |
def content_less_used():
|
2357 |
st.write('---')
|
2358 |
st.write(':octagonal_sign: ***NOTE:*** Settings below are not relevant for most projects. Some settings below may not be reflected in saved settings files and would need to be set each time.')
|
2359 |
|
|
|
2360 |
#################################################################################################################################################
|
2361 |
# Sidebar #######################################################################################################################################
|
2362 |
#################################################################################################################################################
|
2363 |
def sidebar_content():
|
|
|
|
|
|
|
|
|
|
|
2364 |
if not os.path.exists(os.path.join(st.session_state.dir_home,'expense_report')):
|
2365 |
validate_dir(os.path.join(st.session_state.dir_home,'expense_report'))
|
2366 |
expense_report_path = os.path.join(st.session_state.dir_home, 'expense_report', 'expense_report.csv')
|
@@ -2377,7 +2123,6 @@ def sidebar_content():
|
|
2377 |
st.write('Available after first run...')
|
2378 |
|
2379 |
|
2380 |
-
|
2381 |
#################################################################################################################################################
|
2382 |
# Routing Function ##############################################################################################################################
|
2383 |
#################################################################################################################################################
|
@@ -2387,28 +2132,20 @@ def main():
|
|
2387 |
sidebar_content()
|
2388 |
# Main App
|
2389 |
content_header()
|
2390 |
-
|
2391 |
|
2392 |
col_input, col_gallery = st.columns([4,8])
|
2393 |
content_project_settings(col_input)
|
2394 |
content_input_images(col_input, col_gallery)
|
2395 |
|
2396 |
-
# if st.session_state['is_hf']:
|
2397 |
-
# content_project_settings()
|
2398 |
-
# content_input_images_hf()
|
2399 |
-
# else:
|
2400 |
-
# col1, col2 = st.columns([1,1])
|
2401 |
-
# with col1:
|
2402 |
-
# content_project_settings()
|
2403 |
-
# with col2:
|
2404 |
-
# content_input_images()
|
2405 |
-
|
2406 |
|
2407 |
col3, col4 = st.columns([1,1])
|
2408 |
with col3:
|
2409 |
content_prompt_and_llm_version()
|
2410 |
with col4:
|
2411 |
content_api_check()
|
|
|
|
|
|
|
2412 |
content_collage_overlay()
|
2413 |
content_llm_cost()
|
2414 |
content_processing_options()
|
@@ -2418,155 +2155,20 @@ def main():
|
|
2418 |
content_space_saver()
|
2419 |
|
2420 |
|
2421 |
-
|
2422 |
-
|
2423 |
-
|
2424 |
-
|
2425 |
-
#################################################################################################################################################
|
2426 |
-
# Initializations ###############################################################################################################################
|
2427 |
-
#################################################################################################################################################
|
2428 |
-
|
2429 |
-
|
2430 |
-
|
2431 |
-
|
2432 |
-
|
2433 |
-
|
2434 |
-
|
2435 |
-
if st.session_state['is_hf']:
|
2436 |
-
if 'proceed_to_main' not in st.session_state:
|
2437 |
-
st.session_state.proceed_to_main = True
|
2438 |
-
|
2439 |
-
if 'proceed_to_private' not in st.session_state:
|
2440 |
-
st.session_state.proceed_to_private = False
|
2441 |
-
|
2442 |
-
if 'private_file' not in st.session_state:
|
2443 |
-
st.session_state.private_file = True
|
2444 |
-
|
2445 |
-
else:
|
2446 |
-
if 'proceed_to_main' not in st.session_state:
|
2447 |
-
st.session_state.proceed_to_main = False # New state variable to control the flow
|
2448 |
-
|
2449 |
-
if 'private_file' not in st.session_state:
|
2450 |
-
st.session_state.private_file = does_private_file_exist()
|
2451 |
-
if st.session_state.private_file:
|
2452 |
-
st.session_state.proceed_to_main = True
|
2453 |
-
|
2454 |
-
if 'proceed_to_private' not in st.session_state:
|
2455 |
-
st.session_state.proceed_to_private = False # New state variable to control the flow
|
2456 |
-
|
2457 |
-
|
2458 |
-
|
2459 |
-
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
2460 |
-
st.session_state.proceed_to_build_llm_prompt = False # New state variable to control the flow
|
2461 |
-
|
2462 |
-
|
2463 |
-
if 'processing_add_on' not in st.session_state:
|
2464 |
-
st.session_state['processing_add_on'] = 0
|
2465 |
-
|
2466 |
-
|
2467 |
-
if 'formatted_json' not in st.session_state:
|
2468 |
-
st.session_state['formatted_json'] = None
|
2469 |
-
if 'formatted_json_WFO' not in st.session_state:
|
2470 |
-
st.session_state['formatted_json_WFO'] = None
|
2471 |
-
if 'formatted_json_GEO' not in st.session_state:
|
2472 |
-
st.session_state['formatted_json_GEO'] = None
|
2473 |
-
|
2474 |
-
|
2475 |
-
if 'lacks_GPU' not in st.session_state:
|
2476 |
-
st.session_state['lacks_GPU'] = not torch.cuda.is_available()
|
2477 |
-
|
2478 |
-
|
2479 |
-
if 'API_key_validation' not in st.session_state:
|
2480 |
-
st.session_state['API_key_validation'] = False
|
2481 |
-
if 'present_annotations' not in st.session_state:
|
2482 |
-
st.session_state['present_annotations'] = None
|
2483 |
-
if 'missing_annotations' not in st.session_state:
|
2484 |
-
st.session_state['missing_annotations'] = None
|
2485 |
-
if 'date_of_check' not in st.session_state:
|
2486 |
-
st.session_state['date_of_check'] = None
|
2487 |
-
if 'API_checked' not in st.session_state:
|
2488 |
-
st.session_state['API_checked'] = False
|
2489 |
-
if 'API_rechecked' not in st.session_state:
|
2490 |
-
st.session_state['API_rechecked'] = False
|
2491 |
-
|
2492 |
-
|
2493 |
-
if 'json_report' not in st.session_state:
|
2494 |
-
st.session_state['json_report'] = False
|
2495 |
-
if 'hold_output' not in st.session_state:
|
2496 |
-
st.session_state['hold_output'] = False
|
2497 |
-
|
2498 |
-
|
2499 |
-
|
2500 |
-
|
2501 |
-
|
2502 |
-
if 'cost_openai' not in st.session_state:
|
2503 |
-
st.session_state['cost_openai'] = None
|
2504 |
-
if 'cost_azure' not in st.session_state:
|
2505 |
-
st.session_state['cost_azure'] = None
|
2506 |
-
if 'cost_google' not in st.session_state:
|
2507 |
-
st.session_state['cost_google'] = None
|
2508 |
-
if 'cost_mistral' not in st.session_state:
|
2509 |
-
st.session_state['cost_mistral'] = None
|
2510 |
-
if 'cost_local' not in st.session_state:
|
2511 |
-
st.session_state['cost_local'] = None
|
2512 |
-
|
2513 |
-
|
2514 |
-
if 'settings_filename' not in st.session_state:
|
2515 |
-
st.session_state['settings_filename'] = None
|
2516 |
-
if 'loaded_settings_filename' not in st.session_state:
|
2517 |
-
st.session_state['loaded_settings_filename'] = None
|
2518 |
-
if 'zip_filepath' not in st.session_state:
|
2519 |
-
st.session_state['zip_filepath'] = None
|
2520 |
-
|
2521 |
-
|
2522 |
-
# Initialize session_state variables if they don't exist
|
2523 |
-
if 'prompt_info' not in st.session_state:
|
2524 |
-
st.session_state['prompt_info'] = {}
|
2525 |
-
if 'rules' not in st.session_state:
|
2526 |
-
st.session_state['rules'] = {}
|
2527 |
-
|
2528 |
-
|
2529 |
-
# These are the fields that are in SLTPvA that are not required by another parsing valication function:
|
2530 |
-
# "identifiedBy": "M.W. Lyon, Jr.",
|
2531 |
-
# "recordedBy": "University of Michigan Herbarium",
|
2532 |
-
# "recordNumber": "",
|
2533 |
-
# "habitat": "wet subdunal woods",
|
2534 |
-
# "occurrenceRemarks": "Indiana : Porter Co.",
|
2535 |
-
# "degreeOfEstablishment": "",
|
2536 |
-
# "minimumElevationInMeters": "",
|
2537 |
-
# "maximumElevationInMeters": ""
|
2538 |
-
if 'required_fields' not in st.session_state:
|
2539 |
-
st.session_state['required_fields'] = ['catalogNumber','order','family','scientificName',
|
2540 |
-
'scientificNameAuthorship','genus','subgenus','specificEpithet','infraspecificEpithet',
|
2541 |
-
'verbatimEventDate','eventDate',
|
2542 |
-
'country','stateProvince','county','municipality','locality','decimalLatitude','decimalLongitude','verbatimCoordinates',]
|
2543 |
-
|
2544 |
-
|
2545 |
-
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
2546 |
-
st.session_state.proceed_to_build_llm_prompt = False
|
2547 |
-
if 'proceed_to_component_detector' not in st.session_state:
|
2548 |
-
st.session_state.proceed_to_component_detector = False
|
2549 |
-
if 'proceed_to_parsing_options' not in st.session_state:
|
2550 |
-
st.session_state.proceed_to_parsing_options = False
|
2551 |
-
if 'proceed_to_api_keys' not in st.session_state:
|
2552 |
-
st.session_state.proceed_to_api_keys = False
|
2553 |
-
if 'proceed_to_space_saver' not in st.session_state:
|
2554 |
-
st.session_state.proceed_to_space_saver = False
|
2555 |
-
|
2556 |
-
|
2557 |
#################################################################################################################################################
|
2558 |
# Main ##########################################################################################################################################
|
2559 |
#################################################################################################################################################
|
2560 |
if st.session_state['is_hf']:
|
2561 |
-
if st.session_state.proceed_to_build_llm_prompt:
|
2562 |
-
|
2563 |
-
|
2564 |
main()
|
|
|
2565 |
else:
|
2566 |
if not st.session_state.private_file:
|
2567 |
create_private_file()
|
2568 |
-
elif st.session_state.proceed_to_build_llm_prompt:
|
2569 |
-
|
2570 |
elif st.session_state.proceed_to_private and not st.session_state['is_hf']:
|
2571 |
create_private_file()
|
2572 |
elif st.session_state.proceed_to_main:
|
|
|
2 |
import yaml, os, json, random, time, re, torch, random, warnings, shutil, sys
|
3 |
import seaborn as sns
|
4 |
import plotly.graph_objs as go
|
|
|
5 |
from PIL import Image
|
6 |
import pandas as pd
|
7 |
from io import BytesIO
|
|
|
14 |
from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, validate_dir
|
15 |
from vouchervision.model_maps import ModelMaps
|
16 |
from vouchervision.API_validation import APIvalidation
|
17 |
+
from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file, save_uploaded_local
|
18 |
+
from vouchervision.data_project import convert_pdf_to_jpg
|
19 |
+
from vouchervision.utils_LLM import check_system_gpus
|
20 |
|
21 |
|
22 |
#################################################################################################################################################
|
23 |
# Initializations ###############################################################################################################################
|
24 |
#################################################################################################################################################
|
25 |
+
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision',initial_sidebar_state="collapsed")
|
|
|
26 |
|
27 |
# Parse the 'is_hf' argument and set it in session state
|
28 |
if 'is_hf' not in st.session_state:
|
29 |
+
try:
|
30 |
+
is_hf_os = os.getenv('IS_HF')
|
31 |
+
if is_hf_os == 1 or is_hf_os == '1' or is_hf_os or is_hf_os == 'true' or is_hf_os == 'True':
|
32 |
+
st.session_state['is_hf'] = True
|
33 |
+
else:
|
34 |
+
st.session_state['is_hf'] = False
|
35 |
+
except:
|
36 |
+
st.session_state['is_hf'] = False
|
37 |
+
print(f"is_hf {st.session_state['is_hf']}")
|
38 |
|
39 |
|
|
|
|
|
|
|
|
|
40 |
# Default YAML file path
|
41 |
if 'config' not in st.session_state:
|
42 |
st.session_state.config, st.session_state.dir_home = build_VV_config(loaded_cfg=None)
|
43 |
setup_streamlit_config(st.session_state.dir_home)
|
44 |
|
45 |
+
|
46 |
+
########################################################################################################
|
47 |
+
### Global constants ####
|
48 |
+
########################################################################################################
|
49 |
+
MAX_GALLERY_IMAGES = 20
|
50 |
+
GALLERY_IMAGE_SIZE = 96
|
51 |
+
|
52 |
+
|
53 |
+
########################################################################################################
|
54 |
+
### Init funcs ####
|
55 |
+
########################################################################################################
|
56 |
+
def does_private_file_exist():
|
57 |
+
dir_home = os.path.dirname(__file__)
|
58 |
+
path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
59 |
+
return os.path.exists(path_cfg_private)
|
60 |
+
|
61 |
+
|
62 |
+
########################################################################################################
|
63 |
+
### Streamlit inits [FOR SAVE FILE] ####
|
64 |
+
########################################################################################################
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
########################################################################################################
|
70 |
+
### Streamlit inits [routing] ####
|
71 |
+
########################################################################################################
|
72 |
+
if st.session_state['is_hf']:
|
73 |
+
if 'proceed_to_main' not in st.session_state:
|
74 |
+
st.session_state.proceed_to_main = True
|
75 |
+
|
76 |
+
if 'proceed_to_private' not in st.session_state:
|
77 |
+
st.session_state.proceed_to_private = False
|
78 |
+
|
79 |
+
if 'private_file' not in st.session_state:
|
80 |
+
st.session_state.private_file = True
|
81 |
+
else:
|
82 |
+
if 'proceed_to_main' not in st.session_state:
|
83 |
+
st.session_state.proceed_to_main = False # New state variable to control the flow
|
84 |
+
|
85 |
+
if 'private_file' not in st.session_state:
|
86 |
+
st.session_state.private_file = does_private_file_exist()
|
87 |
+
if st.session_state.private_file:
|
88 |
+
st.session_state.proceed_to_main = True
|
89 |
+
|
90 |
+
if 'proceed_to_private' not in st.session_state:
|
91 |
+
st.session_state.proceed_to_private = False # New state variable to control the flow
|
92 |
+
|
93 |
+
|
94 |
+
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
95 |
+
st.session_state.proceed_to_build_llm_prompt = False # New state variable to control the flow
|
96 |
+
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
97 |
+
st.session_state.proceed_to_build_llm_prompt = False
|
98 |
+
if 'proceed_to_component_detector' not in st.session_state:
|
99 |
+
st.session_state.proceed_to_component_detector = False
|
100 |
+
if 'proceed_to_parsing_options' not in st.session_state:
|
101 |
+
st.session_state.proceed_to_parsing_options = False
|
102 |
+
if 'proceed_to_api_keys' not in st.session_state:
|
103 |
+
st.session_state.proceed_to_api_keys = False
|
104 |
+
if 'proceed_to_space_saver' not in st.session_state:
|
105 |
+
st.session_state.proceed_to_space_saver = False
|
106 |
+
if 'proceed_to_faqs' not in st.session_state:
|
107 |
+
st.session_state.proceed_to_faqs = False
|
108 |
+
|
109 |
+
|
110 |
+
########################################################################################################
|
111 |
+
### Streamlit inits [basics] ####
|
112 |
+
########################################################################################################
|
113 |
+
if 'processing_add_on' not in st.session_state:
|
114 |
+
st.session_state['processing_add_on'] = 0
|
115 |
+
|
116 |
+
|
117 |
+
if 'capability_score' not in st.session_state:
|
118 |
+
st.session_state['num_gpus'], st.session_state['gpu_dict'], st.session_state['total_vram_gb'], st.session_state['capability_score'] = check_system_gpus()
|
119 |
+
|
120 |
+
|
121 |
+
if 'formatted_json' not in st.session_state:
|
122 |
+
st.session_state['formatted_json'] = None
|
123 |
+
if 'formatted_json_WFO' not in st.session_state:
|
124 |
+
st.session_state['formatted_json_WFO'] = None
|
125 |
+
if 'formatted_json_GEO' not in st.session_state:
|
126 |
+
st.session_state['formatted_json_GEO'] = None
|
127 |
+
|
128 |
+
|
129 |
+
if 'lacks_GPU' not in st.session_state:
|
130 |
+
st.session_state['lacks_GPU'] = not torch.cuda.is_available()
|
131 |
+
|
132 |
+
|
133 |
+
if 'API_key_validation' not in st.session_state:
|
134 |
+
st.session_state['API_key_validation'] = False
|
135 |
+
if 'API_checked' not in st.session_state:
|
136 |
+
st.session_state['API_checked'] = False
|
137 |
+
if 'API_rechecked' not in st.session_state:
|
138 |
+
st.session_state['API_rechecked'] = False
|
139 |
+
|
140 |
+
|
141 |
+
if 'present_annotations' not in st.session_state:
|
142 |
+
st.session_state['present_annotations'] = None
|
143 |
+
if 'missing_annotations' not in st.session_state:
|
144 |
+
st.session_state['missing_annotations'] = None
|
145 |
+
if 'date_of_check' not in st.session_state:
|
146 |
+
st.session_state['date_of_check'] = None
|
147 |
+
|
148 |
+
|
149 |
+
if 'json_report' not in st.session_state:
|
150 |
+
st.session_state['json_report'] = False
|
151 |
+
if 'hold_output' not in st.session_state:
|
152 |
+
st.session_state['hold_output'] = False
|
153 |
+
|
154 |
+
|
155 |
+
if 'cost_openai' not in st.session_state:
|
156 |
+
st.session_state['cost_openai'] = None
|
157 |
+
if 'cost_azure' not in st.session_state:
|
158 |
+
st.session_state['cost_azure'] = None
|
159 |
+
if 'cost_google' not in st.session_state:
|
160 |
+
st.session_state['cost_google'] = None
|
161 |
+
if 'cost_mistral' not in st.session_state:
|
162 |
+
st.session_state['cost_mistral'] = None
|
163 |
+
if 'cost_local' not in st.session_state:
|
164 |
+
st.session_state['cost_local'] = None
|
165 |
+
|
166 |
+
|
167 |
+
if 'settings_filename' not in st.session_state:
|
168 |
+
st.session_state['settings_filename'] = None
|
169 |
+
if 'loaded_settings_filename' not in st.session_state:
|
170 |
+
st.session_state['loaded_settings_filename'] = None
|
171 |
+
if 'zip_filepath' not in st.session_state:
|
172 |
+
st.session_state['zip_filepath'] = None
|
173 |
+
|
174 |
+
|
175 |
+
########################################################################################################
|
176 |
+
### Streamlit inits [prompt builder] ####
|
177 |
+
########################################################################################################
|
178 |
+
# These are the fields that are in SLTPvA that are not required by another parsing valication function:
|
179 |
+
# "identifiedBy": "M.W. Lyon, Jr.",
|
180 |
+
# "recordedBy": "University of Michigan Herbarium",
|
181 |
+
# "recordNumber": "",
|
182 |
+
# "habitat": "wet subdunal woods",
|
183 |
+
# "occurrenceRemarks": "Indiana : Porter Co.",
|
184 |
+
# "degreeOfEstablishment": "",
|
185 |
+
# "minimumElevationInMeters": "",
|
186 |
+
# "maximumElevationInMeters": ""
|
187 |
+
if 'required_fields' not in st.session_state:
|
188 |
+
st.session_state['required_fields'] = ['catalogNumber','order','family','scientificName',
|
189 |
+
'scientificNameAuthorship','genus','subgenus','specificEpithet','infraspecificEpithet',
|
190 |
+
'verbatimEventDate','eventDate',
|
191 |
+
'country','stateProvince','county','municipality','locality','decimalLatitude','decimalLongitude','verbatimCoordinates',]
|
192 |
+
if 'prompt_info' not in st.session_state:
|
193 |
+
st.session_state['prompt_info'] = {}
|
194 |
+
if 'rules' not in st.session_state:
|
195 |
+
st.session_state['rules'] = {}
|
196 |
+
|
197 |
+
|
198 |
+
########################################################################################################
|
199 |
+
### Streamlit inits [gallery] ####
|
200 |
+
########################################################################################################
|
201 |
if 'uploader_idk' not in st.session_state:
|
202 |
st.session_state['uploader_idk'] = 1
|
203 |
if 'input_list_small' not in st.session_state:
|
|
|
219 |
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
220 |
validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))
|
221 |
|
|
|
|
|
222 |
|
223 |
|
224 |
|
225 |
+
########################################################################################################
|
226 |
+
### CONTENT [] ####
|
227 |
+
########################################################################################################
|
228 |
def content_input_images(col_left, col_right):
|
229 |
st.write('---')
|
230 |
# col1, col2 = st.columns([2,8])
|
|
|
243 |
if st.session_state.is_hf:
|
244 |
st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
|
245 |
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
246 |
+
uploaded_files = st.file_uploader("Upload Images", type=['jpg', 'jpeg','pdf'], accept_multiple_files=True, key=st.session_state['uploader_idk'])
|
247 |
st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)
|
248 |
|
249 |
with col_right:
|
|
|
252 |
# Clear input image gallery and input list
|
253 |
clear_image_gallery()
|
254 |
|
|
|
255 |
for uploaded_file in uploaded_files:
|
256 |
+
# Determine the file type
|
257 |
+
if uploaded_file.name.lower().endswith('.pdf'):
|
258 |
+
# Handle PDF files
|
259 |
+
file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
|
260 |
+
# Convert each page of the PDF to an image
|
261 |
+
n_pages = convert_pdf_to_jpg(file_path, st.session_state['dir_uploaded_images'], dpi=st.session_state.config['leafmachine']['project']['dir_images_local'])
|
262 |
+
# Update the input list for each page image
|
263 |
+
converted_files = os.listdir(st.session_state['dir_uploaded_images'])
|
264 |
+
|
265 |
+
for file_name in converted_files:
|
266 |
+
if file_name.lower().endswith('.jpg'):
|
267 |
+
jpg_file_path = os.path.join(st.session_state['dir_uploaded_images'], file_name)
|
268 |
+
st.session_state['input_list'].append(jpg_file_path)
|
269 |
+
|
270 |
+
# Optionally, create a thumbnail for the gallery
|
271 |
+
img = Image.open(jpg_file_path)
|
272 |
+
img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
|
273 |
+
file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
|
274 |
+
st.session_state['input_list_small'].append(file_path_small)
|
275 |
+
else:
|
276 |
+
# Handle JPG/JPEG files (existing process)
|
277 |
+
file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
|
278 |
+
st.session_state['input_list'].append(file_path)
|
279 |
+
img = Image.open(file_path)
|
280 |
+
img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
|
281 |
+
file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
|
282 |
+
st.session_state['input_list_small'].append(file_path_small)
|
283 |
+
|
284 |
+
# After processing all files
|
285 |
+
st.info(f"Processing images from {st.session_state.config['leafmachine']['project']['dir_images_local']}")
|
286 |
|
287 |
if st.session_state['input_list_small']:
|
288 |
if len(st.session_state['input_list_small']) > MAX_GALLERY_IMAGES:
|
|
|
320 |
st.session_state['dir_images_local_TEMP'] = st.session_state.config['leafmachine']['project']['dir_images_local']
|
321 |
print("rerun")
|
322 |
st.rerun()
|
|
|
323 |
|
324 |
def list_jpg_files(directory_path):
|
325 |
jpg_count = 0
|
|
|
412 |
st.session_state['input_list_small'].append(file_path_small)
|
413 |
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
def refresh():
|
416 |
st.session_state['uploader_idk'] += 1
|
417 |
st.write('')
|
418 |
|
419 |
|
420 |
|
421 |
+
|
422 |
+
|
423 |
# def display_image_gallery():
|
424 |
# # Initialize the container
|
425 |
# con_image = st.empty()
|
|
|
660 |
|
661 |
|
662 |
|
663 |
+
|
|
|
|
|
|
|
664 |
|
665 |
|
666 |
|
|
|
1112 |
# st.session_state.private_file = does_private_file_exist()
|
1113 |
|
1114 |
# Function to load a YAML file and update session_state
|
1115 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1116 |
|
1117 |
### Updated to match HF version
|
1118 |
# def save_prompt_yaml(filename):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1120 |
|
1121 |
|
1122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1123 |
def show_header_welcome():
|
1124 |
st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
|
1125 |
st.session_state.logo = Image.open(st.session_state.logo_path)
|
|
|
1297 |
with col_run_4:
|
1298 |
with st.expander("View Messages and Updates"):
|
1299 |
st.info("***Note:*** If you use VoucherVision frequently, you can change the default values that are auto-populated in the form below. In a text editor or IDE, edit the first few rows in the file `../VoucherVision/vouchervision/VoucherVision_Config_Builder.py`")
|
1300 |
+
st.info("Please enable LeafMachine2 collage for full-sized images of herbarium vouchers, you will get better results!")
|
1301 |
|
1302 |
col_test = st.container()
|
1303 |
|
|
|
1307 |
col_json, col_json_WFO, col_json_GEO, col_json_map = st.columns([2, 2, 2, 2])
|
1308 |
|
1309 |
with col_run_info_1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1310 |
# Progress
|
1311 |
overall_progress_bar = st.progress(0)
|
1312 |
text_overall = st.empty() # Placeholder for current step name
|
|
|
1314 |
batch_progress_bar = st.progress(0)
|
1315 |
text_batch = st.empty() # Placeholder for current step name
|
1316 |
progress_report = ProgressReport(overall_progress_bar, batch_progress_bar, text_overall, text_batch)
|
|
|
1317 |
st.session_state['hold_output'] = st.toggle('View Final Transcription')
|
1318 |
|
1319 |
with col_logo:
|
1320 |
show_header_welcome()
|
1321 |
|
1322 |
with col_run_1:
|
|
|
1323 |
N_STEPS = 6
|
1324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1325 |
if check_if_usable(is_hf=st.session_state['is_hf']):
|
1326 |
b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
|
1327 |
if st.session_state['processing_add_on'] == 0:
|
|
|
1345 |
total_cost = 0.00
|
1346 |
n_failed_OCR = 0
|
1347 |
n_failed_LLM_calls = 0
|
1348 |
+
# try:
|
1349 |
+
voucher_vision_output = voucher_vision(None,
|
1350 |
+
st.session_state.dir_home,
|
1351 |
+
path_custom_prompts,
|
1352 |
+
None,
|
1353 |
+
progress_report,
|
1354 |
+
st.session_state['json_report'],
|
1355 |
+
path_api_cost=os.path.join(st.session_state.dir_home,'api_cost','api_cost.yaml'),
|
1356 |
+
is_hf = st.session_state['is_hf'],
|
1357 |
+
is_real_run=True)
|
1358 |
+
st.session_state['formatted_json'] = voucher_vision_output['last_JSON_response']
|
1359 |
+
st.session_state['formatted_json_WFO'] = voucher_vision_output['final_WFO_record']
|
1360 |
+
st.session_state['formatted_json_GEO'] = voucher_vision_output['final_GEO_record']
|
1361 |
+
total_cost = voucher_vision_output['total_cost']
|
1362 |
+
n_failed_OCR = voucher_vision_output['n_failed_OCR']
|
1363 |
+
n_failed_LLM_calls = voucher_vision_output['n_failed_LLM_calls']
|
1364 |
+
st.session_state['zip_filepath'] = voucher_vision_output['zip_filepath']
|
1365 |
+
# st.balloons()
|
1366 |
+
|
1367 |
+
# except Exception as e:
|
1368 |
+
# with col_run_4:
|
1369 |
+
# st.error(f"Transcription failed. Error: {e}")
|
1370 |
|
1371 |
if n_failed_OCR > 0:
|
1372 |
with col_run_4:
|
|
|
1403 |
with ct_left:
|
1404 |
st.button("Refresh", on_click=refresh, use_container_width=True)
|
1405 |
with ct_right:
|
1406 |
+
# st.page_link(os.path.join(os.path.dirname(__file__),"pages","faqs.py"), label="FAQs", icon="❔")
|
1407 |
+
st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
|
1408 |
+
|
1409 |
+
# if st.button('FAQs', use_container_width=True):
|
1410 |
+
# st.session_state.proceed_to_faqs = True
|
1411 |
+
# st.session_state.proceed_to_main = False
|
1412 |
+
# st.rerun()
|
1413 |
|
1414 |
# with col_run_2:
|
1415 |
# if st.button("Test GPT"):
|
|
|
1486 |
|
1487 |
|
1488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1489 |
def content_project_settings(col):
|
1490 |
### Project
|
1491 |
with col:
|
|
|
1575 |
st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version),label_visibility='collapsed')
|
1576 |
|
1577 |
with col_prompt_2:
|
1578 |
+
# if st.button("Build Custom LLM Prompt"):
|
1579 |
+
# st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
|
1580 |
+
st.page_link(os.path.join("pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
|
1581 |
+
|
1582 |
|
1583 |
st.header('LLM Version')
|
1584 |
col_llm_1, col_llm_2 = st.columns([4,2])
|
|
|
1614 |
st.rerun()
|
1615 |
|
1616 |
|
|
|
|
|
1617 |
|
1618 |
+
def adjust_ocr_options_based_on_capability(capability_score):
|
1619 |
+
llava_models_requirements = {
|
1620 |
+
"liuhaotian/llava-v1.6-mistral-7b": {"full": 18, "4bit": 9},
|
1621 |
+
"liuhaotian/llava-v1.6-34b": {"full": 70, "4bit": 25},
|
1622 |
+
"liuhaotian/llava-v1.6-vicuna-13b": {"full": 33, "4bit": 15},
|
1623 |
+
"liuhaotian/llava-v1.6-vicuna-7b": {"full": 20, "4bit": 10},
|
1624 |
+
}
|
1625 |
+
if capability_score == 'no_gpu':
|
1626 |
+
return False
|
1627 |
+
else:
|
1628 |
+
capability_score_n = int(capability_score.split("_")[1].split("GB")[0])
|
1629 |
+
supported_models = [model for model, reqs in llava_models_requirements.items()
|
1630 |
+
if reqs["full"] <= capability_score_n or reqs["4bit"] <= capability_score_n]
|
1631 |
+
|
1632 |
+
# If no models are supported, disable the LLaVA option
|
1633 |
+
if not supported_models:
|
1634 |
+
# Assuming the LLaVA option is the last in your list
|
1635 |
+
return False # Indicate LLaVA is not supported
|
1636 |
+
return True # Indicate LLaVA is supported
|
1637 |
+
|
1638 |
+
|
1639 |
+
|
1640 |
+
def content_ocr_method():
|
1641 |
st.write("---")
|
1642 |
+
st.header('OCR Methods')
|
1643 |
+
with st.expander("Read about available OCR methods"):
|
1644 |
+
st.subheader("Overview")
|
1645 |
+
st.markdown("""VoucherVision can use the `Google Vision API`, `CRAFT` text detection + `trOCR`, and all `LLaVA v1.6` models.
|
1646 |
+
VoucherVision sends the OCR inside of the LLM prompt. We have found that sending multiple copies, or multiple version of
|
1647 |
+
the OCR text to the LLM helps the LLM maintain focus on the OCR text -- our prompts are quite long and the OCR text is reletively short.
|
1648 |
+
Below you can choose the OCR method/s. You can 'stack' all of the methods if you want, which may improve results because
|
1649 |
+
different OCR methods have different strengths, giving the LLM more information to work with. Alternative.y, you can select a single method and
|
1650 |
+
send 2 copies to the LLM by enabling that option below.""")
|
1651 |
+
st.subheader("Google Vision API")
|
1652 |
+
st.markdown("""`Google Vision API` provides several OCR methods. We use the `document_text_detection()` service, designed to handle dense text blocks.
|
1653 |
+
The `Handwritten` option CAN also be used for printed and mixed labels, but it is also optimized for handwriting. `Handwritten` uses the Google Vision Beta service.
|
1654 |
+
This is the recommended default OCR method. `Printed` uses the regular Google Vision service and works well for general use.
|
1655 |
+
You can also supplement Google Vision OCR by enabling trOCR, which is optimized for handwriting. trOCR requires segmented word images, which is provided as part
|
1656 |
+
of the Google Vision metadata. trOCR does not require a GPU, but it runs *much* faster with a GPU.""")
|
1657 |
+
st.subheader("LLaVA")
|
1658 |
+
st.markdown("""`LLaVA` can replace Google Vision APIs. It requires the use of LeafMachine2 collage, or images that are majority text. It may struggle with very
|
1659 |
+
long texts. LLaVA models are multimodal, meaning that we can upload the image and the model will transcribe (and even parse) the text all at once. With VoucherVision, we
|
1660 |
+
support 4 different LLaVA models of varying sizes, some are much more capable than others. These models tend to outperform all other OCR methods for handwriting.
|
1661 |
+
LLaVA models are run locally and require powerful GPUs to implement. While LLaVA models are capable of handling both the OCR and text parsing tasks all in one step,
|
1662 |
+
this option only uses LLaVA to transcribe all of the text in the image and still uses a separate LLM to parse text in to categories. """)
|
1663 |
+
st.subheader("CRAFT + trOCR")
|
1664 |
+
st.markdown("""This pairing can replace Google Vision APIs and is computationally lighter than LLaVA. `CRAFT` locates text, segments lines of text, and feeds the segmentations
|
1665 |
+
to the `trOCR` transformer model. This pairing requires at least an 8 GB GPU. trOCR is a Microsoft model optimized for handwriting. The base model is not as accurate as
|
1666 |
+
LLaVA or Google Vision, but if you have a trOCR-based model, let us know and we will add support.""")
|
1667 |
+
|
1668 |
+
c1, c2 = st.columns([4,4])
|
1669 |
+
|
1670 |
+
# Check if LLaVA models are supported based on capability score
|
1671 |
+
llava_supported = adjust_ocr_options_based_on_capability(st.session_state.capability_score)
|
1672 |
+
if llava_supported:
|
1673 |
+
st.success("LLaVA models are supported on this computer")
|
1674 |
+
else:
|
1675 |
+
st.warning("LLaVA models are NOT supported on this computer. Requires a GPU with at least 12 GB of VRAM.")
|
1676 |
+
|
1677 |
demo_text_h = f"Google_OCR_Handwriting:\nHERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 TX 11 Ilowers pink UNIVERSITE HERBARIUM MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
|
1678 |
demo_text_tr = f"trOCR:\nherbarium of marcus w. lyon jr. : : : tracaulon sagittatum indiana porter co. incal springs TX 11 Ilowers pink 1439649 copyright reserved D H U Q "
|
1679 |
demo_text_p = f"Google_OCR_Printed:\nTracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 Ilowers pink 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
|
|
|
1682 |
demo_text_trh = demo_text_h + '\n' + demo_text_tr
|
1683 |
demo_text_trp = demo_text_p + '\n' + demo_text_tr
|
1684 |
|
1685 |
+
options = ["Google Vision Handwritten", "Google Vision Printed", "CRAFT + trOCR","LLaVA"]
|
1686 |
+
options_llava = ["llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",]
|
1687 |
+
options_llava_bit = ["full", "4bit",]
|
1688 |
+
captions_llava = [
|
1689 |
+
"Full Model: 18 GB VRAM, 4-bit: 9 GB VRAM",
|
1690 |
+
"Full Model: 70 GB VRAM, 4-bit: 25 GB VRAM",
|
1691 |
+
"Full Model: 33 GB VRAM, 4-bit: 15 GB VRAM",
|
1692 |
+
"Full Model: 20 GB VRAM, 4-bit: 10 GB VRAM",
|
1693 |
+
]
|
1694 |
+
captions_llava_bit = ["Full Model","4-bit Quantization",]
|
1695 |
+
# Get the current OCR option from session state
|
1696 |
+
OCR_option = st.session_state.config['leafmachine']['project']['OCR_option']
|
1697 |
+
OCR_option_llava = st.session_state.config['leafmachine']['project']['OCR_option_llava']
|
1698 |
+
OCR_option_llava_bit = st.session_state.config['leafmachine']['project']['OCR_option_llava_bit']
|
1699 |
+
double_OCR = st.session_state.config['leafmachine']['project']['double_OCR']
|
1700 |
+
|
1701 |
+
# Map the OCR option to the index in options list
|
1702 |
+
# You need to define the mapping based on your application's logic
|
1703 |
+
default_index = 0 # Default to 0 if option not found
|
1704 |
+
default_index_llava = 0 # Default to 0 if option not found
|
1705 |
+
default_index_llava_bit = 0
|
1706 |
+
with c1:
|
1707 |
+
st.subheader("API Methods (Google Vision)")
|
1708 |
+
st.write("Using APIs for OCR allows VoucherVision to run on most computers.")
|
1709 |
+
|
1710 |
+
st.session_state.config['leafmachine']['project']['double_OCR'] = st.checkbox(label="Send 2 copies of the OCR to the LLM",
|
1711 |
+
help="This can help the LLMs focus attention on the OCR and not get lost in the longer instruction text",
|
1712 |
+
value=double_OCR)
|
1713 |
+
|
1714 |
+
# Create the radio button
|
1715 |
+
# OCR_option_select = st.radio(
|
1716 |
+
# "Select the OCR Method",
|
1717 |
+
# options,
|
1718 |
+
# index=default_index,
|
1719 |
+
# help="",captions=captions,
|
1720 |
+
# )
|
1721 |
+
default_values = [options[default_index]]
|
1722 |
+
OCR_option_select = st.multiselect(
|
1723 |
+
"Select the OCR Method(s)",
|
1724 |
+
options=options,
|
1725 |
+
default=default_values,
|
1726 |
+
help="Select one or more OCR methods."
|
1727 |
+
)
|
1728 |
+
# st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option_select
|
1729 |
+
|
1730 |
+
# Handling multiple selections (Example logic)
|
1731 |
+
OCR_options = {
|
1732 |
+
"Google Vision Handwritten": 'hand',
|
1733 |
+
"Google Vision Printed": 'normal',
|
1734 |
+
"CRAFT + trOCR": 'CRAFT',
|
1735 |
+
"LLaVA": 'LLaVA',
|
1736 |
+
}
|
1737 |
+
|
1738 |
+
# Map selected options to their corresponding internal representations
|
1739 |
+
selected_OCR_options = [OCR_options[option] for option in OCR_option_select]
|
1740 |
+
|
1741 |
+
# Assuming you need to use these mapped values elsewhere in your application
|
1742 |
+
st.session_state.config['leafmachine']['project']['OCR_option'] = selected_OCR_options
|
1743 |
+
|
1744 |
+
|
1745 |
+
with c2:
|
1746 |
+
st.subheader("Local Methods")
|
1747 |
+
st.write("Local methods are free, but require a capable GPU. ")
|
1748 |
+
|
1749 |
+
|
1750 |
+
st.write("Supplement Google Vision OCR with trOCR (handwriting OCR) using `microsoft/trocr-base-handwritten`. This option requires Google Vision API and a GPU.")
|
1751 |
+
if 'CRAFT' in selected_OCR_options:
|
1752 |
+
do_use_trOCR = st.checkbox("Enable trOCR", value=True, key="Enable trOCR1",disabled=True)#,disabled=st.session_state['lacks_GPU'])
|
1753 |
+
else:
|
1754 |
+
do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'],key="Enable trOCR2")#,disabled=st.session_state['lacks_GPU'])
|
1755 |
+
st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
|
1756 |
+
|
1757 |
+
if 'LLaVA' in selected_OCR_options:
|
1758 |
+
OCR_option_llava = st.radio(
|
1759 |
+
"Select the LLaVA version",
|
1760 |
+
options_llava,
|
1761 |
+
index=default_index_llava,
|
1762 |
+
help="",captions=captions_llava,
|
1763 |
+
)
|
1764 |
+
st.session_state.config['leafmachine']['project']['OCR_option_llava'] = OCR_option_llava
|
1765 |
+
|
1766 |
+
OCR_option_llava_bit = st.radio(
|
1767 |
+
"Select the LLaVA quantization level",
|
1768 |
+
options_llava_bit,
|
1769 |
+
index=default_index_llava_bit,
|
1770 |
+
help="",captions=captions_llava_bit,
|
1771 |
+
)
|
1772 |
+
st.session_state.config['leafmachine']['project']['OCR_option_llava_bit'] = OCR_option_llava_bit
|
1773 |
+
|
1774 |
+
|
1775 |
+
|
1776 |
+
# st.markdown("Below is an example of what the LLM would see given the choice of OCR ensemble. One, two, or three version of OCR can be fed into the LLM prompt. Typically, 'printed + handwritten' works well. If you have a GPU then you can enable trOCR.")
|
1777 |
+
# if (OCR_option == 'hand') and not do_use_trOCR:
|
1778 |
+
# st.text_area(label='Handwritten/Printed',placeholder=demo_text_h,disabled=True, label_visibility='visible', height=150)
|
1779 |
+
# elif (OCR_option == 'normal') and not do_use_trOCR:
|
1780 |
+
# st.text_area(label='Printed',placeholder=demo_text_p,disabled=True, label_visibility='visible', height=150)
|
1781 |
+
# elif (OCR_option == 'both') and not do_use_trOCR:
|
1782 |
+
# st.text_area(label='Handwritten/Printed + Printed',placeholder=demo_text_b,disabled=True, label_visibility='visible', height=150)
|
1783 |
+
# elif (OCR_option == 'both') and do_use_trOCR:
|
1784 |
+
# st.text_area(label='Handwritten/Printed + Printed + trOCR',placeholder=demo_text_trb,disabled=True, label_visibility='visible', height=150)
|
1785 |
+
# elif (OCR_option == 'normal') and do_use_trOCR:
|
1786 |
+
# st.text_area(label='Printed + trOCR',placeholder=demo_text_trp,disabled=True, label_visibility='visible', height=150)
|
1787 |
+
# elif (OCR_option == 'hand') and do_use_trOCR:
|
1788 |
+
# st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
|
1789 |
+
|
1790 |
+
|
1791 |
+
|
1792 |
+
def content_collage_overlay():
|
1793 |
+
st.write("---")
|
1794 |
+
col_collage, col_overlay = st.columns([4,4])
|
1795 |
+
|
1796 |
+
|
1797 |
+
|
1798 |
with col_collage:
|
1799 |
st.header('LeafMachine2 Label Collage')
|
1800 |
+
st.info("NOTE: We strongly recommend enabling LeafMachine2 cropping if your images are full sized herbarium sheet. Often, the OCR algorithm struggles with full sheets, but works well with the collage images. We have disabled the collage by default for this Hugging Face Space because the Space lacks a GPU and the collage creation takes a bit longer.")
|
1801 |
default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
|
1802 |
st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
|
1803 |
+
st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))
|
1804 |
|
1805 |
|
1806 |
option_selected_crops = st.multiselect(label="Components to crop",
|
|
|
1817 |
with st.expander(":frame_with_picture: View an example of the LeafMachine2 collage image"):
|
1818 |
st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="PNG")
|
1819 |
# st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="JPEG")
|
|
|
|
|
1820 |
|
1821 |
with col_overlay:
|
1822 |
st.header('OCR Overlay Image')
|
|
|
|
|
|
|
|
|
|
|
|
|
1823 |
|
1824 |
st.write('This will plot bounding boxes around all text that Google Vision was able to detect. If there are no boxes around text, then the OCR failed, so that missing text will not be seen by the LLM when it is creating the JSON object. The created image will be viewable in the VoucherVisionEditor.')
|
1825 |
|
1826 |
do_create_OCR_helper_image = st.checkbox("Create image showing an overlay of the OCR detections",value=st.session_state.config['leafmachine']['do_create_OCR_helper_image'],disabled=True)
|
1827 |
st.session_state.config['leafmachine']['do_create_OCR_helper_image'] = do_create_OCR_helper_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1828 |
|
1829 |
if "demo_overlay" not in st.session_state:
|
1830 |
# ocr = os.path.join(st.session_state.dir_home,'demo', 'ba','ocr.png')
|
|
|
1874 |
st.subheader('Compute Options')
|
1875 |
st.session_state.config['leafmachine']['project']['num_workers'] = st.number_input("Number of CPU workers", value=st.session_state.config['leafmachine']['project'].get('num_workers', 1), disabled=False)
|
1876 |
st.session_state.config['leafmachine']['project']['batch_size'] = st.number_input("Batch size", value=st.session_state.config['leafmachine']['project'].get('batch_size', 500), help='Sets the batch size for the LeafMachine2 cropping. If computer RAM is filled, lower this value to ~100.')
|
1877 |
+
st.session_state.config['leafmachine']['project']['pdf_conversion_dpi'] = st.number_input("PDF conversion DPI", value=st.session_state.config['leafmachine']['project'].get('pdf_conversion_dpi', 100), help='DPI of the JPG created from the page of a PDF. 100 should be fine for most cases, but 200 or 300 might be better for large images.')
|
1878 |
+
|
1879 |
with col_processing_2:
|
1880 |
st.subheader('Filename Prefix Handling')
|
1881 |
st.session_state.config['leafmachine']['project']['prefix_removal'] = st.text_input("Remove prefix from catalog number", st.session_state.config['leafmachine']['project'].get('prefix_removal', ''),placeholder="e.g. MICH-V-")
|
|
|
1884 |
|
1885 |
### Logging and Image Validation - col_v1
|
1886 |
st.write("---")
|
|
|
1887 |
col_v1, col_v2 = st.columns(2)
|
1888 |
+
|
1889 |
with col_v1:
|
1890 |
+
st.header('Logging and Image Validation')
|
1891 |
option_check_illegal = st.checkbox("Check for illegal filenames", value=st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'])
|
1892 |
st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'] = option_check_illegal
|
1893 |
+
|
1894 |
+
option_skip_vertical = st.checkbox("Skip vertical image requirement (e.g. horizontal PDFs)", value=st.session_state.config['leafmachine']['do']['skip_vertical'],help='LeafMachine2 label collage requires images to have vertical aspect ratios for stability. If your input images have a horizonatal aspect ratio, try skipping the vertical requirement first, look for strange behavior, and then reassess. If your image/PDFs are already closeups and you do not need the collage, then skipping the vertical requirement is the right choice.')
|
1895 |
+
st.session_state.config['leafmachine']['do']['skip_vertical'] = option_skip_vertical
|
1896 |
+
|
1897 |
st.session_state.config['leafmachine']['do']['check_for_corrupt_images_make_vertical'] = st.checkbox("Check for corrupt images", st.session_state.config['leafmachine']['do'].get('check_for_corrupt_images_make_vertical', True),disabled=True)
|
1898 |
|
1899 |
st.session_state.config['leafmachine']['print']['verbose'] = st.checkbox("Print verbose", st.session_state.config['leafmachine']['print'].get('verbose', True))
|
1900 |
st.session_state.config['leafmachine']['print']['optional_warnings'] = st.checkbox("Show optional warnings", st.session_state.config['leafmachine']['print'].get('optional_warnings', True))
|
1901 |
+
|
|
|
1902 |
log_level = st.session_state.config['leafmachine']['logging'].get('log_level', None)
|
1903 |
log_level_display = log_level if log_level is not None else 'default'
|
1904 |
selected_log_level = st.selectbox("Logging Level", ['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'], index=['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'].index(log_level_display))
|
|
|
1908 |
else:
|
1909 |
st.session_state.config['leafmachine']['logging']['log_level'] = selected_log_level
|
1910 |
|
1911 |
+
with col_v2:
|
1912 |
+
|
1913 |
+
|
1914 |
+
print(f"Number of GPUs: {st.session_state.num_gpus}")
|
1915 |
+
print(f"GPU Details: {st.session_state.gpu_dict}")
|
1916 |
+
print(f"Total VRAM: {st.session_state.total_vram_gb} GB")
|
1917 |
+
print(f"Capability Score: {st.session_state.capability_score}")
|
1918 |
+
|
1919 |
+
st.header('System GPU Information')
|
1920 |
+
st.markdown(f"**Torch CUDA:** {torch.cuda.is_available()}")
|
1921 |
+
st.markdown(f"**Number of GPUs:** {st.session_state.num_gpus}")
|
1922 |
+
|
1923 |
+
if st.session_state.num_gpus > 0:
|
1924 |
+
st.markdown("**GPU Details:**")
|
1925 |
+
for gpu_id, vram in st.session_state.gpu_dict.items():
|
1926 |
+
st.text(f"{gpu_id}: {vram}")
|
1927 |
+
|
1928 |
+
st.markdown(f"**Total VRAM:** {st.session_state.total_vram_gb} GB")
|
1929 |
+
st.markdown(f"**Capability Score:** {st.session_state.capability_score}")
|
1930 |
+
else:
|
1931 |
+
st.warning("No GPUs detected in the system.")
|
1932 |
+
|
1933 |
|
1934 |
|
1935 |
def content_tab_domain():
|
|
|
1996 |
expense_report = st.session_state.expense_report
|
1997 |
st.header('Expense Report Summary')
|
1998 |
|
1999 |
+
if not expense_summary:
|
2000 |
+
st.warning('No expense report data available.')
|
2001 |
+
else:
|
2002 |
st.metric(label="Total Cost", value=f"${round(expense_summary['total_cost_sum'], 4):,}")
|
2003 |
col1, col2 = st.columns(2)
|
2004 |
|
|
|
2092 |
pie_chart.update_traces(marker=dict(colors=colors),)
|
2093 |
st.plotly_chart(pie_chart, use_container_width=True)
|
2094 |
|
|
|
|
|
|
|
|
|
2095 |
|
2096 |
def content_less_used():
|
2097 |
st.write('---')
|
2098 |
st.write(':octagonal_sign: ***NOTE:*** Settings below are not relevant for most projects. Some settings below may not be reflected in saved settings files and would need to be set each time.')
|
2099 |
|
2100 |
+
|
2101 |
#################################################################################################################################################
|
2102 |
# Sidebar #######################################################################################################################################
|
2103 |
#################################################################################################################################################
|
2104 |
def sidebar_content():
|
2105 |
+
# st.page_link(os.path.join(os.path.dirname(__file__),'app.py'), label="Home", icon="🏠")
|
2106 |
+
# st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
|
2107 |
+
# st.page_link("pages/page_2.py", label="Page 2", icon="2️⃣", disabled=True)
|
2108 |
+
# st.page_link("http://www.google.com", label="Google", icon="🌎")
|
2109 |
+
|
2110 |
if not os.path.exists(os.path.join(st.session_state.dir_home,'expense_report')):
|
2111 |
validate_dir(os.path.join(st.session_state.dir_home,'expense_report'))
|
2112 |
expense_report_path = os.path.join(st.session_state.dir_home, 'expense_report', 'expense_report.csv')
|
|
|
2123 |
st.write('Available after first run...')
|
2124 |
|
2125 |
|
|
|
2126 |
#################################################################################################################################################
|
2127 |
# Routing Function ##############################################################################################################################
|
2128 |
#################################################################################################################################################
|
|
|
2132 |
sidebar_content()
|
2133 |
# Main App
|
2134 |
content_header()
|
|
|
2135 |
|
2136 |
col_input, col_gallery = st.columns([4,8])
|
2137 |
content_project_settings(col_input)
|
2138 |
content_input_images(col_input, col_gallery)
|
2139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2140 |
|
2141 |
col3, col4 = st.columns([1,1])
|
2142 |
with col3:
|
2143 |
content_prompt_and_llm_version()
|
2144 |
with col4:
|
2145 |
content_api_check()
|
2146 |
+
|
2147 |
+
content_ocr_method()
|
2148 |
+
|
2149 |
content_collage_overlay()
|
2150 |
content_llm_cost()
|
2151 |
content_processing_options()
|
|
|
2155 |
content_space_saver()
|
2156 |
|
2157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2158 |
#################################################################################################################################################
|
2159 |
# Main ##########################################################################################################################################
|
2160 |
#################################################################################################################################################
|
2161 |
if st.session_state['is_hf']:
|
2162 |
+
# if st.session_state.proceed_to_build_llm_prompt:
|
2163 |
+
# build_LLM_prompt_config()
|
2164 |
+
if st.session_state.proceed_to_main:
|
2165 |
main()
|
2166 |
+
|
2167 |
else:
|
2168 |
if not st.session_state.private_file:
|
2169 |
create_private_file()
|
2170 |
+
# elif st.session_state.proceed_to_build_llm_prompt:
|
2171 |
+
# build_LLM_prompt_config()
|
2172 |
elif st.session_state.proceed_to_private and not st.session_state['is_hf']:
|
2173 |
create_private_file()
|
2174 |
elif st.session_state.proceed_to_main:
|
install_dependencies.sh
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# List of packages to be installed
|
4 |
+
packages=(
|
5 |
+
wheel
|
6 |
+
gputil
|
7 |
+
streamlit
|
8 |
+
streamlit-extras
|
9 |
+
streamlit-elements==0.1.*
|
10 |
+
plotly
|
11 |
+
google-api-python-client
|
12 |
+
wikipedia
|
13 |
+
PyMuPDF
|
14 |
+
craft-text-detector
|
15 |
+
pyyaml
|
16 |
+
Pillow
|
17 |
+
bitsandbytes
|
18 |
+
accelerate
|
19 |
+
mapboxgl
|
20 |
+
pandas
|
21 |
+
matplotlib
|
22 |
+
matplotlib-inline
|
23 |
+
tqdm
|
24 |
+
openai
|
25 |
+
langchain
|
26 |
+
langchain-community
|
27 |
+
langchain-core
|
28 |
+
langchain_mistralai
|
29 |
+
langchain_openai
|
30 |
+
langchain_google_genai
|
31 |
+
langchain_experimental
|
32 |
+
jsonformer
|
33 |
+
vertexai
|
34 |
+
ctransformers
|
35 |
+
google-cloud-aiplatform
|
36 |
+
tiktoken
|
37 |
+
llama-cpp-python
|
38 |
+
openpyxl
|
39 |
+
google-generativeai
|
40 |
+
google-cloud-storage
|
41 |
+
google-cloud-vision
|
42 |
+
opencv-python
|
43 |
+
chromadb
|
44 |
+
chroma-migrate
|
45 |
+
InstructorEmbedding
|
46 |
+
transformers
|
47 |
+
sentence-transformers
|
48 |
+
seaborn
|
49 |
+
dask
|
50 |
+
psutil
|
51 |
+
py-cpuinfo
|
52 |
+
Levenshtein
|
53 |
+
fuzzywuzzy
|
54 |
+
opencage
|
55 |
+
geocoder
|
56 |
+
pycountry_convert
|
57 |
+
)
|
58 |
+
|
59 |
+
# Function to install a single package
|
60 |
+
install_package() {
|
61 |
+
package=$1
|
62 |
+
echo "Installing $package..."
|
63 |
+
pip3 install $package
|
64 |
+
if [ $? -ne 0 ]; then
|
65 |
+
echo "Failed to install $package"
|
66 |
+
exit 1
|
67 |
+
fi
|
68 |
+
}
|
69 |
+
|
70 |
+
# Install each package individually
|
71 |
+
for package in "${packages[@]}"; do
|
72 |
+
install_package $package
|
73 |
+
done
|
74 |
+
|
75 |
+
echo "All packages installed successfully."
|
76 |
+
echo "Cloning and installing LLaVA..."
|
77 |
+
|
78 |
+
|
79 |
+
cd vouchervision
|
80 |
+
git clone https://github.com/haotian-liu/LLaVA.git
|
81 |
+
cd LLaVA # Assuming you want to run pip install in the LLaVA directory
|
82 |
+
pip install -e .
|
83 |
+
git pull
|
84 |
+
pip install -e .
|
85 |
+
echo "LLaVA ready"
|
pages/faqs.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import streamlit.components.v1 as components
|
4 |
+
|
5 |
+
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VV FAQs',initial_sidebar_state="collapsed")
|
6 |
+
|
7 |
+
def display_faqs():
|
8 |
+
c1, c2, c3 = st.columns([4,6,1])
|
9 |
+
with c3:
|
10 |
+
# st.page_link(os.path.join(os.path.dirname(os.path.dirname(__file__)),'app.py'), label="Home", icon="🏠")
|
11 |
+
# st.page_link(os.path.join(os.path.dirname(os.path.dirname(__file__)),"pages","faqs.py"), label="FAQs", icon="❔")
|
12 |
+
# st.page_link(os.path.join(os.path.dirname(os.path.dirname(__file__)),"pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
|
13 |
+
st.page_link('app.py', label="Home", icon="🏠")
|
14 |
+
st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
|
15 |
+
st.page_link(os.path.join("pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
|
16 |
+
with c2:
|
17 |
+
st.write('If you would like to get more involved, have questions, would like to see additional features, then please fill out this [Google Form](https://docs.google.com/forms/d/e/1FAIpQLSe2E9zU1bPJ1BW4PMakEQFsRmLbQ0WTBI2UXHIMEFm4WbnAVw/viewform?usp=sf_link)')
|
18 |
+
components.iframe(f"https://docs.google.com/forms/d/e/1FAIpQLSe2E9zU1bPJ1BW4PMakEQFsRmLbQ0WTBI2UXHIMEFm4WbnAVw/viewform?embedded=true", height=900,scrolling=True,width=640)
|
19 |
+
|
20 |
+
with c1:
|
21 |
+
st.header('FAQs')
|
22 |
+
st.subheader('Lead Institution')
|
23 |
+
st.write('- University of Michigan')
|
24 |
+
|
25 |
+
st.subheader('Partner Institutions')
|
26 |
+
st.write('- Oregon State University')
|
27 |
+
st.write('- University of Colorado Boulder')
|
28 |
+
st.write('- Botanical Research Institute of Texas')
|
29 |
+
st.write('- Smithsonian National Museum of Natural History')
|
30 |
+
st.write('- South African National Biodiversity Institute')
|
31 |
+
st.write('- Botanischer Garten Berlin')
|
32 |
+
st.write('- Freie Universität Berlin')
|
33 |
+
st.write('- Morton Arboretum')
|
34 |
+
st.write('- Florida Museum')
|
35 |
+
st.write('- iDigBio')
|
36 |
+
st.write('**More soon!**')
|
37 |
+
|
38 |
+
display_faqs()
|
pages/prompt_builder.py
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, yaml
|
2 |
+
import streamlit as st
|
3 |
+
from PIL import Image
|
4 |
+
from itertools import chain
|
5 |
+
|
6 |
+
from vouchervision.model_maps import ModelMaps
|
7 |
+
from vouchervision.utils_hf import check_prompt_yaml_filename
|
8 |
+
|
9 |
+
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VV Prompt Builder',initial_sidebar_state="collapsed")
|
10 |
+
|
11 |
+
def create_download_button_yaml(file_path, selected_yaml_file, key_val):
|
12 |
+
file_label = f"Download {selected_yaml_file}"
|
13 |
+
with open(file_path, 'rb') as f:
|
14 |
+
st.download_button(
|
15 |
+
label=file_label,
|
16 |
+
data=f,
|
17 |
+
file_name=os.path.basename(file_path),
|
18 |
+
mime='application/x-yaml',use_container_width=True,key=key_val,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def upload_local_prompt_to_server(dir_prompt):
|
23 |
+
uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
|
24 |
+
if uploaded_file is not None:
|
25 |
+
# Check the file extension
|
26 |
+
file_name = uploaded_file.name
|
27 |
+
if file_name.endswith('.yaml'):
|
28 |
+
file_path = os.path.join(dir_prompt, file_name)
|
29 |
+
|
30 |
+
# Save the file
|
31 |
+
with open(file_path, 'wb') as f:
|
32 |
+
f.write(uploaded_file.getbuffer())
|
33 |
+
st.success(f"Saved file {file_name} in {dir_prompt}")
|
34 |
+
else:
|
35 |
+
st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")
|
36 |
+
|
37 |
+
|
38 |
+
def save_prompt_yaml(filename, col):
|
39 |
+
yaml_content = {
|
40 |
+
'prompt_author': st.session_state['prompt_author'],
|
41 |
+
'prompt_author_institution': st.session_state['prompt_author_institution'],
|
42 |
+
'prompt_name': st.session_state['prompt_name'],
|
43 |
+
'prompt_version': st.session_state['prompt_version'],
|
44 |
+
'prompt_description': st.session_state['prompt_description'],
|
45 |
+
'LLM': st.session_state['LLM'],
|
46 |
+
'instructions': st.session_state['instructions'],
|
47 |
+
'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
48 |
+
'rules': st.session_state['rules'],
|
49 |
+
'mapping': st.session_state['mapping'],
|
50 |
+
}
|
51 |
+
|
52 |
+
dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
|
53 |
+
filepath = os.path.join(dir_prompt, f"{filename}.yaml")
|
54 |
+
|
55 |
+
with open(filepath, 'w') as file:
|
56 |
+
yaml.safe_dump(dict(yaml_content), file, sort_keys=False)
|
57 |
+
|
58 |
+
st.success(f"Prompt saved as '{filename}.yaml'.")
|
59 |
+
|
60 |
+
with col: # added
|
61 |
+
create_download_button_yaml(filepath, filename,key_val=2456237465) # added
|
62 |
+
|
63 |
+
|
64 |
+
def load_prompt_yaml(filename):
|
65 |
+
st.session_state['user_clicked_load_prompt_yaml'] = filename
|
66 |
+
with open(filename, 'r') as file:
|
67 |
+
st.session_state['prompt_info'] = yaml.safe_load(file)
|
68 |
+
st.session_state['prompt_author'] = st.session_state['prompt_info'].get('prompt_author', st.session_state['default_prompt_author'])
|
69 |
+
st.session_state['prompt_author_institution'] = st.session_state['prompt_info'].get('prompt_author_institution', st.session_state['default_prompt_author_institution'])
|
70 |
+
st.session_state['prompt_name'] = st.session_state['prompt_info'].get('prompt_name', st.session_state['default_prompt_name'])
|
71 |
+
st.session_state['prompt_version'] = st.session_state['prompt_info'].get('prompt_version', st.session_state['default_prompt_version'])
|
72 |
+
st.session_state['prompt_description'] = st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description'])
|
73 |
+
st.session_state['instructions'] = st.session_state['prompt_info'].get('instructions', st.session_state['default_instructions'])
|
74 |
+
st.session_state['json_formatting_instructions'] = st.session_state['prompt_info'].get('json_formatting_instructions', st.session_state['default_json_formatting_instructions'] )
|
75 |
+
st.session_state['rules'] = st.session_state['prompt_info'].get('rules', {})
|
76 |
+
st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
|
77 |
+
st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
|
78 |
+
|
79 |
+
# Placeholder:
|
80 |
+
st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
|
81 |
+
|
82 |
+
|
83 |
+
def btn_load_prompt(selected_yaml_file, dir_prompt):
|
84 |
+
if selected_yaml_file:
|
85 |
+
yaml_file_path = os.path.join(dir_prompt, selected_yaml_file)
|
86 |
+
load_prompt_yaml(yaml_file_path)
|
87 |
+
elif not selected_yaml_file:
|
88 |
+
# Directly assigning default values since no file is selected
|
89 |
+
st.session_state['prompt_info'] = {}
|
90 |
+
st.session_state['prompt_author'] = st.session_state['default_prompt_author']
|
91 |
+
st.session_state['prompt_author_institution'] = st.session_state['default_prompt_author_institution']
|
92 |
+
st.session_state['prompt_name'] = st.session_state['prompt_name']
|
93 |
+
st.session_state['prompt_version'] = st.session_state['prompt_version']
|
94 |
+
st.session_state['prompt_description'] = st.session_state['default_prompt_description']
|
95 |
+
st.session_state['instructions'] = st.session_state['default_instructions']
|
96 |
+
st.session_state['json_formatting_instructions'] = st.session_state['default_json_formatting_instructions']
|
97 |
+
st.session_state['rules'] = {}
|
98 |
+
st.session_state['LLM'] = 'General Purpose'
|
99 |
+
|
100 |
+
st.session_state['assigned_columns'] = []
|
101 |
+
|
102 |
+
st.session_state['prompt_info'] = {
|
103 |
+
'prompt_author': st.session_state['prompt_author'],
|
104 |
+
'prompt_author_institution': st.session_state['prompt_author_institution'],
|
105 |
+
'prompt_name': st.session_state['prompt_name'],
|
106 |
+
'prompt_version': st.session_state['prompt_version'],
|
107 |
+
'prompt_description': st.session_state['prompt_description'],
|
108 |
+
'instructions': st.session_state['instructions'],
|
109 |
+
'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
110 |
+
'rules': st.session_state['rules'],
|
111 |
+
'mapping': st.session_state['mapping'],
|
112 |
+
'LLM': st.session_state['LLM']
|
113 |
+
}
|
114 |
+
|
115 |
+
|
116 |
+
def check_unique_mapping_assignments():
|
117 |
+
print(st.session_state['assigned_columns'])
|
118 |
+
if len(st.session_state['assigned_columns']) != len(set(st.session_state['assigned_columns'])):
|
119 |
+
st.error("Each column name must be assigned to only one category.")
|
120 |
+
return False
|
121 |
+
elif not st.session_state['assigned_columns']:
|
122 |
+
st.error("No columns have been mapped.")
|
123 |
+
return False
|
124 |
+
elif len(st.session_state['assigned_columns']) != len(st.session_state['rules'].keys()):
|
125 |
+
incomplete = [item for item in list(st.session_state['rules'].keys()) if item not in st.session_state['assigned_columns']]
|
126 |
+
st.warning(f"These columns have been mapped: {st.session_state['assigned_columns']}")
|
127 |
+
st.error(f"However, these columns must be mapped before the prompt is complete: {incomplete}")
|
128 |
+
return False
|
129 |
+
else:
|
130 |
+
st.success("Mapping confirmed.")
|
131 |
+
return True
|
132 |
+
|
133 |
+
|
134 |
+
def build_LLM_prompt_config():
|
135 |
+
col_main1, col_main2 = st.columns([10,2])
|
136 |
+
with col_main1:
|
137 |
+
st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
|
138 |
+
st.session_state.logo = Image.open(st.session_state.logo_path)
|
139 |
+
st.image(st.session_state.logo, width=250)
|
140 |
+
with col_main2:
|
141 |
+
st.page_link('app.py', label="Home", icon="🏠")
|
142 |
+
st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
|
143 |
+
st.page_link(os.path.join("pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
|
144 |
+
# st.page_link("pages/page_2.py", label="Page 2", icon="2️⃣", disabled=True)
|
145 |
+
# st.page_link("http://www.google.com", label="Google", icon="🌎")
|
146 |
+
|
147 |
+
st.session_state['assigned_columns'] = []
|
148 |
+
st.session_state['default_prompt_author'] = 'unknown'
|
149 |
+
st.session_state['default_prompt_author_institution'] = 'unknown'
|
150 |
+
st.session_state['default_prompt_name'] = 'custom_prompt'
|
151 |
+
st.session_state['default_prompt_version'] = 'v-1-0'
|
152 |
+
st.session_state['default_prompt_author_institution'] = 'unknown'
|
153 |
+
st.session_state['default_prompt_description'] = 'unknown'
|
154 |
+
st.session_state['default_LLM'] = 'General Purpose'
|
155 |
+
st.session_state['default_instructions'] = """1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
|
156 |
+
2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
|
157 |
+
3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
|
158 |
+
4. Duplicate dictionary fields are not allowed.
|
159 |
+
5. Ensure all JSON keys are in camel case.
|
160 |
+
6. Ensure new JSON field values follow sentence case capitalization.
|
161 |
+
7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
|
162 |
+
8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
|
163 |
+
9. Only return a JSON dictionary represented as a string. You should not explain your answer."""
|
164 |
+
st.session_state['default_json_formatting_instructions'] = """This section provides rules for formatting each JSON value organized by the JSON key."""
|
165 |
+
|
166 |
+
# Start building the Streamlit app
|
167 |
+
col_prompt_main_left, ___, col_prompt_main_right = st.columns([6,1,3])
|
168 |
+
|
169 |
+
|
170 |
+
with col_prompt_main_left:
|
171 |
+
|
172 |
+
st.title("Custom LLM Prompt Builder")
|
173 |
+
st.subheader('About')
|
174 |
+
st.write("This form allows you to craft a prompt for your specific task. You can also edit the JSON yaml files directly, but please try loading the prompt back into this form to ensure that the formatting is correct. If this form cannot load your manually edited JSON yaml file, then it will not work in VoucherVision.")
|
175 |
+
st.subheader(':rainbow[How it Works]')
|
176 |
+
st.write("1. Edit this page until you are happy with your instructions. We recommend looking at the basic structure, writing down your prompt inforamtion in a Word document so that it does not randomly disappear, and then copying and pasting that info into this form once your whole prompt structure is defined.")
|
177 |
+
st.write("2. After you enter all of your prompt instructions, click 'Save' and give your file a name.")
|
178 |
+
st.write("3. This file will be saved as a yaml configuration file in the `..VoucherVision/custom_prompts` folder.")
|
179 |
+
st.write("4. When you go back the main VoucherVision page you will now see your custom prompt available in the 'Prompt Version' dropdown menu.")
|
180 |
+
st.write("5. The LLM ***only*** sees information from the 'instructions', 'rules', and 'json_formatting_instructions' sections. All other information is for versioning and integration with VoucherVisionEditor.")
|
181 |
+
|
182 |
+
st.write("---")
|
183 |
+
st.header('Load an Existing Prompt Template')
|
184 |
+
st.write("By default, this form loads the minimum required transcription fields but does not provide rules for each field. You can also load an existing prompt as a template, editing or deleting values as needed.")
|
185 |
+
|
186 |
+
dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
|
187 |
+
yaml_files = [f for f in os.listdir(dir_prompt) if f.endswith('.yaml')]
|
188 |
+
col_load_text, col_load_btn, col_load_btn2 = st.columns([8,2,2])
|
189 |
+
with col_load_text:
|
190 |
+
# Dropdown for selecting a YAML file
|
191 |
+
st.session_state['selected_yaml_file'] = st.selectbox('Select a prompt .YAML file to load:', [''] + yaml_files)
|
192 |
+
with col_load_btn:
|
193 |
+
st.write('##')
|
194 |
+
# Button to load the selected prompt
|
195 |
+
st.button('Load Prompt', on_click=btn_load_prompt, args=[st.session_state['selected_yaml_file'], dir_prompt],use_container_width=True)
|
196 |
+
|
197 |
+
with col_load_btn2:
|
198 |
+
if st.session_state['selected_yaml_file']:
|
199 |
+
# Construct the full path to the file
|
200 |
+
download_file_path = os.path.join(dir_prompt, st.session_state['selected_yaml_file'] )
|
201 |
+
# Create the download button
|
202 |
+
st.write('##')
|
203 |
+
create_download_button_yaml(download_file_path, st.session_state['selected_yaml_file'],key_val=345798)
|
204 |
+
|
205 |
+
# Prompt Author Information
|
206 |
+
st.write("---")
|
207 |
+
st.header("Prompt Author Information")
|
208 |
+
st.write("We value community contributions! Please provide your name(s) (or pseudonym if you prefer) for credit. If you leave this field blank, it will say 'unknown'.")
|
209 |
+
if 'prompt_author' not in st.session_state:# != st.session_state['default_prompt_author']:
|
210 |
+
st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['default_prompt_author'],key=1111)
|
211 |
+
else:
|
212 |
+
st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['prompt_author'],key=1112)
|
213 |
+
|
214 |
+
# Institution
|
215 |
+
st.write("Please provide your institution name. If you leave this field blank, it will say 'unknown'.")
|
216 |
+
if 'prompt_author_institution' not in st.session_state:
|
217 |
+
st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['default_prompt_author_institution'],key=1113)
|
218 |
+
else:
|
219 |
+
st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['prompt_author_institution'],key=1114)
|
220 |
+
|
221 |
+
# Prompt name
|
222 |
+
st.write("Please provide a simple name for your prompt. If you leave this field blank, it will say 'custom_prompt'.")
|
223 |
+
if 'prompt_name' not in st.session_state:
|
224 |
+
st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['default_prompt_name'],key=1115)
|
225 |
+
else:
|
226 |
+
st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['prompt_name'],key=1116)
|
227 |
+
|
228 |
+
# Prompt verion
|
229 |
+
st.write("Please provide a version identifier for your prompt. If you leave this field blank, it will say 'v-1-0'.")
|
230 |
+
if 'prompt_version' not in st.session_state:
|
231 |
+
st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['default_prompt_version'],key=1117)
|
232 |
+
else:
|
233 |
+
st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['prompt_version'],key=1118)
|
234 |
+
|
235 |
+
|
236 |
+
st.write("Please provide a description of your prompt and its intended task. Is it designed for a specific collection? Taxa? Database structure?")
|
237 |
+
if 'prompt_description' not in st.session_state:
|
238 |
+
st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['default_prompt_description'],key=1119)
|
239 |
+
else:
|
240 |
+
st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['prompt_description'],key=11111)
|
241 |
+
|
242 |
+
st.write('---')
|
243 |
+
st.header("Set LLM Model Type")
|
244 |
+
# Define the options for the dropdown
|
245 |
+
llm_options_general = ["General Purpose",
|
246 |
+
"OpenAI GPT Models","Google PaLM2 Models","Google Gemini Models","MistralAI Models",]
|
247 |
+
llm_options_all = ModelMaps.get_models_gui_list()
|
248 |
+
|
249 |
+
if 'LLM' not in st.session_state:
|
250 |
+
st.session_state['LLM'] = st.session_state['default_LLM']
|
251 |
+
|
252 |
+
if st.session_state['LLM']:
|
253 |
+
llm_options = llm_options_general + llm_options_all + [st.session_state['LLM']]
|
254 |
+
else:
|
255 |
+
llm_options = llm_options_general + llm_options_all
|
256 |
+
# Create the dropdown and set the value to session_state['LLM']
|
257 |
+
st.write("Which LLM is this prompt designed for? This will not restrict its use to a specific LLM, but some prompts will behave differently across models.")
|
258 |
+
st.write("SLTPvA prompts have been validated with all supported LLMs, but perfornce may vary. If you design a prompt to work best with a specific model, then you can indicate the model here.")
|
259 |
+
st.write("For general purpose prompts (like the SLTPvA prompts) just use the 'General Purpose' option.")
|
260 |
+
st.session_state['LLM'] = st.selectbox('Set LLM', llm_options, index=llm_options.index(st.session_state.get('LLM', 'General Purpose')))
|
261 |
+
|
262 |
+
st.write('---')
|
263 |
+
# Instructions Section
|
264 |
+
st.header("Instructions")
|
265 |
+
st.write("These are the general instructions that guide the LLM through the transcription task. We recommend using the default instructions unless you have a specific reason to change them.")
|
266 |
+
|
267 |
+
if 'instructions' not in st.session_state:
|
268 |
+
st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['default_instructions'].strip(), height=350,key=111112)
|
269 |
+
else:
|
270 |
+
st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['instructions'].strip(), height=350,key=111112)
|
271 |
+
|
272 |
+
|
273 |
+
st.write('---')
|
274 |
+
|
275 |
+
# Column Instructions Section
|
276 |
+
st.header("JSON Formatting Instructions")
|
277 |
+
st.write("The following section tells the LLM how we want to structure the JSON dictionary. We do not recommend changing this section because it would likely result in unstable and inconsistent behavior.")
|
278 |
+
if 'json_formatting_instructions' not in st.session_state:
|
279 |
+
st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['default_json_formatting_instructions'],key=111114)
|
280 |
+
else:
|
281 |
+
st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['json_formatting_instructions'],key=111115)
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
st.write('---')
|
289 |
+
col_left, col_right = st.columns([6,4])
|
290 |
+
|
291 |
+
null_value_rules = ''
|
292 |
+
c_name = "EXAMPLE_COLUMN_NAME"
|
293 |
+
c_value = "REPLACE WITH DESCRIPTION"
|
294 |
+
|
295 |
+
with col_left:
|
296 |
+
st.subheader('Add/Edit Columns')
|
297 |
+
st.markdown("The pre-populated fields are REQUIRED for downstream validation steps. They must be in all prompts.")
|
298 |
+
|
299 |
+
# Initialize rules in session state if not already present
|
300 |
+
if 'rules' not in st.session_state or not st.session_state['rules']:
|
301 |
+
for required_col in st.session_state['required_fields']:
|
302 |
+
st.session_state['rules'][required_col] = c_value
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
# Layout for adding a new column name
|
308 |
+
# col_text, col_textbtn = st.columns([8, 2])
|
309 |
+
# with col_text:
|
310 |
+
st.session_state['new_column_name'] = st.text_input("Enter a new column name:")
|
311 |
+
# with col_textbtn:
|
312 |
+
# st.write('##')
|
313 |
+
if st.button("Add New Column") and st.session_state['new_column_name']:
|
314 |
+
if st.session_state['new_column_name'] not in st.session_state['rules']:
|
315 |
+
st.session_state['rules'][st.session_state['new_column_name']] = c_value
|
316 |
+
st.success(f"New column '{st.session_state['new_column_name']}' added. Now you can edit its properties.")
|
317 |
+
st.session_state['new_column_name'] = ''
|
318 |
+
else:
|
319 |
+
st.error("Column name already exists. Please enter a unique column name.")
|
320 |
+
st.session_state['new_column_name'] = ''
|
321 |
+
|
322 |
+
# Get columns excluding the protected "catalogNumber"
|
323 |
+
st.write('#')
|
324 |
+
# required_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
|
325 |
+
editable_columns = [col for col in st.session_state['rules'] if col not in ["catalogNumber"]]
|
326 |
+
removable_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
|
327 |
+
|
328 |
+
st.session_state['current_rule'] = st.selectbox("Select a column to edit:", [""] + editable_columns)
|
329 |
+
# column_name = st.selectbox("Select a column to edit:", editable_columns)
|
330 |
+
|
331 |
+
# Form for input fields
|
332 |
+
with st.form(key='rule_form'):
|
333 |
+
# format_options = ["verbatim transcription", "spell check transcription", "boolean yes no", "boolean 1 0", "integer", "[list]", "yyyy-mm-dd"]
|
334 |
+
# current_rule["format"] = st.selectbox("Format:", format_options, index=format_options.index(current_rule["format"]) if current_rule["format"] else 0)
|
335 |
+
# current_rule["null_value"] = st.text_input("Null value:", value=current_rule["null_value"])
|
336 |
+
if st.session_state['current_rule']:
|
337 |
+
current_rule_description = st.text_area("Description of category:", value=st.session_state['rules'][st.session_state['current_rule']])
|
338 |
+
else:
|
339 |
+
current_rule_description = ''
|
340 |
+
commit_button = st.form_submit_button("Commit Column")
|
341 |
+
|
342 |
+
# Handle commit action
|
343 |
+
if commit_button and st.session_state['current_rule']:
|
344 |
+
# Commit the rules to the session state.
|
345 |
+
st.session_state['rules'][st.session_state['current_rule']] = current_rule_description
|
346 |
+
st.success(f"Column '{st.session_state['current_rule']}' added/updated in rules.")
|
347 |
+
|
348 |
+
# Force the form to reset by clearing the fields from the session state
|
349 |
+
st.session_state.pop('current_rule', None) # Clear the selected column to force reset
|
350 |
+
|
351 |
+
delete_column_name = st.selectbox("Select a column to delete:", [""] + removable_columns)
|
352 |
+
# with del_colbtn:
|
353 |
+
# st.write('##')
|
354 |
+
if st.button("Delete Column") and delete_column_name:
|
355 |
+
del st.session_state['rules'][delete_column_name]
|
356 |
+
st.success(f"Column '{delete_column_name}' removed from rules.")
|
357 |
+
|
358 |
+
with col_right:
|
359 |
+
# Display the current state of the JSON rules
|
360 |
+
st.subheader('Formatted Columns')
|
361 |
+
st.json(st.session_state['rules'])
|
362 |
+
|
363 |
+
st.write('---')
|
364 |
+
|
365 |
+
col_left_mapping, col_right_mapping = st.columns([6,4])
|
366 |
+
with col_left_mapping:
|
367 |
+
st.header("Mapping")
|
368 |
+
st.write("Assign each column name to a single category.")
|
369 |
+
st.session_state['refresh_mapping'] = False
|
370 |
+
|
371 |
+
# Dynamically create a list of all column names that can be assigned
|
372 |
+
# This assumes that the column names are the keys in the dictionary under 'rules'
|
373 |
+
all_column_names = list(st.session_state['rules'].keys())
|
374 |
+
|
375 |
+
categories = ['TAXONOMY', 'GEOGRAPHY', 'LOCALITY', 'COLLECTING', 'MISC']
|
376 |
+
if ('mapping' not in st.session_state) or (st.session_state['mapping'] == {}):
|
377 |
+
st.session_state['mapping'] = {category: [] for category in categories}
|
378 |
+
for category in categories:
|
379 |
+
# Filter out the already assigned columns
|
380 |
+
available_columns = [col for col in all_column_names if col not in st.session_state['assigned_columns'] or col in st.session_state['mapping'].get(category, [])]
|
381 |
+
|
382 |
+
# Ensure the current mapping is a subset of the available options
|
383 |
+
current_mapping = [col for col in st.session_state['mapping'].get(category, []) if col in available_columns]
|
384 |
+
|
385 |
+
# Provide a safe default if the current mapping is empty or contains invalid options
|
386 |
+
safe_default = current_mapping if all(col in available_columns for col in current_mapping) else []
|
387 |
+
|
388 |
+
# Create a multi-select widget for the category with a safe default
|
389 |
+
selected_columns = st.multiselect(
|
390 |
+
f"Select columns for {category}:",
|
391 |
+
available_columns,
|
392 |
+
default=safe_default,
|
393 |
+
key=f"mapping_{category}"
|
394 |
+
)
|
395 |
+
# Update the assigned_columns based on the selections
|
396 |
+
for col in current_mapping:
|
397 |
+
if col not in selected_columns and col in st.session_state['assigned_columns']:
|
398 |
+
st.session_state['assigned_columns'].remove(col)
|
399 |
+
st.session_state['refresh_mapping'] = True
|
400 |
+
|
401 |
+
for col in selected_columns:
|
402 |
+
if col not in st.session_state['assigned_columns']:
|
403 |
+
st.session_state['assigned_columns'].append(col)
|
404 |
+
st.session_state['refresh_mapping'] = True
|
405 |
+
|
406 |
+
# Update the mapping in session state when there's a change
|
407 |
+
st.session_state['mapping'][category] = selected_columns
|
408 |
+
if st.session_state['refresh_mapping']:
|
409 |
+
st.session_state['refresh_mapping'] = False
|
410 |
+
|
411 |
+
# Button to confirm and save the mapping configuration
|
412 |
+
if st.button('Confirm Mapping'):
|
413 |
+
if check_unique_mapping_assignments():
|
414 |
+
# Proceed with further actions since the mapping is confirmed and unique
|
415 |
+
pass
|
416 |
+
|
417 |
+
with col_right_mapping:
|
418 |
+
# Display the current state of the JSON rules
|
419 |
+
st.subheader('Formatted Column Maps')
|
420 |
+
st.json(st.session_state['mapping'])
|
421 |
+
|
422 |
+
|
423 |
+
col_left_save, col_right_save = st.columns([6,4])
|
424 |
+
with col_left_save:
|
425 |
+
# Input for new file name
|
426 |
+
new_filename = st.text_input("Enter filename to save your prompt as a configuration YAML:",placeholder='my_prompt_name')
|
427 |
+
# Button to save the new YAML file
|
428 |
+
if st.button('Save YAML', type='primary'):
|
429 |
+
if new_filename:
|
430 |
+
if check_unique_mapping_assignments():
|
431 |
+
if check_prompt_yaml_filename(new_filename):
|
432 |
+
save_prompt_yaml(new_filename, col_left_save)
|
433 |
+
else:
|
434 |
+
st.error("File name can only contain letters, numbers, underscores, and dashes. Cannot contain spaces.")
|
435 |
+
else:
|
436 |
+
st.error("Mapping contains an error. Make sure that each column is assigned to only ***one*** category.")
|
437 |
+
else:
|
438 |
+
st.error("Please enter a filename.")
|
439 |
+
|
440 |
+
if st.button('Exit'):
|
441 |
+
st.session_state.proceed_to_build_llm_prompt = False
|
442 |
+
st.session_state.proceed_to_main = True
|
443 |
+
st.rerun()
|
444 |
+
|
445 |
+
|
446 |
+
with col_prompt_main_right:
|
447 |
+
if st.session_state['user_clicked_load_prompt_yaml'] is None: # see if user has loaded a yaml to edit
|
448 |
+
st.session_state['show_prompt_name_e'] = f"Prompt Status :arrow_forward: Building prompt from scratch"
|
449 |
+
if st.session_state['prompt_name']:
|
450 |
+
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
|
451 |
+
else:
|
452 |
+
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
|
453 |
+
else:
|
454 |
+
st.session_state['show_prompt_name_e'] = f"Prompt Status: Editing :arrow_forward: {st.session_state['selected_yaml_file']}"
|
455 |
+
if st.session_state['prompt_name']:
|
456 |
+
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
|
457 |
+
else:
|
458 |
+
st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
|
459 |
+
|
460 |
+
st.subheader(f'Full Prompt')
|
461 |
+
st.write(st.session_state['show_prompt_name_e'])
|
462 |
+
st.write(st.session_state['show_prompt_name_w'])
|
463 |
+
st.write("---")
|
464 |
+
st.session_state['prompt_info'] = {
|
465 |
+
'prompt_author': st.session_state['prompt_author'],
|
466 |
+
'prompt_author_institution': st.session_state['prompt_author_institution'],
|
467 |
+
'prompt_name': st.session_state['prompt_name'],
|
468 |
+
'prompt_version': st.session_state['prompt_version'],
|
469 |
+
'prompt_description': st.session_state['prompt_description'],
|
470 |
+
'LLM': st.session_state['LLM'],
|
471 |
+
'instructions': st.session_state['instructions'],
|
472 |
+
'json_formatting_instructions': st.session_state['json_formatting_instructions'],
|
473 |
+
'rules': st.session_state['rules'],
|
474 |
+
'mapping': st.session_state['mapping'],
|
475 |
+
}
|
476 |
+
st.json(st.session_state['prompt_info'])
|
477 |
+
|
478 |
+
build_LLM_prompt_config()
|
pages/report_bugs.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import streamlit.components.v1 as components
|
4 |
+
|
5 |
+
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VV Report Bugs',initial_sidebar_state="collapsed")
|
6 |
+
|
7 |
+
def display_report():
|
8 |
+
c1, c2, c3 = st.columns([4,6,1])
|
9 |
+
with c3:
|
10 |
+
st.page_link('app.py', label="Home", icon="🏠")
|
11 |
+
st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
|
12 |
+
st.page_link(os.path.join("pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
|
13 |
+
|
14 |
+
with c2:
|
15 |
+
st.write('To report a bug or request a new feature please fill out this [Google Form](https://docs.google.com/forms/d/e/1FAIpQLSdtW1z9Q1pGZTo5W9UeCa6PlQanP-b88iNKE6zsusRI78Itsw/viewform?usp=sf_link)')
|
16 |
+
components.iframe(f"https://docs.google.com/forms/d/e/1FAIpQLSdtW1z9Q1pGZTo5W9UeCa6PlQanP-b88iNKE6zsusRI78Itsw/viewform?embedded=true", height=700,scrolling=True,width=640)
|
17 |
+
|
18 |
+
|
19 |
+
display_report()
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
run_VoucherVision.py
CHANGED
@@ -1,15 +1,10 @@
|
|
1 |
import streamlit.web.cli as stcli
|
2 |
import os, sys
|
3 |
|
4 |
-
# Insert a file uploader that accepts multiple files at a time:
|
5 |
-
# import streamlit as st
|
6 |
-
# uploaded_files = st.file_uploader("Choose a CSV file", accept_multiple_files=True)
|
7 |
-
# for uploaded_file in uploaded_files:
|
8 |
-
# bytes_data = uploaded_file.read()
|
9 |
-
# st.write("filename:", uploaded_file.name)
|
10 |
-
# st.write(bytes_data)
|
11 |
-
|
12 |
# pip install protobuf==3.20.0
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
def resolve_path(path):
|
|
|
1 |
import streamlit.web.cli as stcli
|
2 |
import os, sys
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
# pip install protobuf==3.20.0
|
5 |
+
# pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 nope
|
6 |
+
# pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
|
7 |
+
|
8 |
|
9 |
|
10 |
def resolve_path(path):
|
vouchervision/LLM_crew_OpenAI.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from crewai import Agent, Task, Crew, Process
|
3 |
+
from langchain_community.tools import DuckDuckGoSearchRun
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class AIResearchCrew:
|
9 |
+
def __init__(self, openai_api_key, OCR, JSON_rules, search_tool=None, llm=None):
|
10 |
+
# Set the OPENAI API key
|
11 |
+
os.environ["OPENAI_API_KEY"] = openai_api_key
|
12 |
+
|
13 |
+
# Initialize the search tool, defaulting to DuckDuckGoSearchRun if not provided
|
14 |
+
self.search_tool = search_tool if search_tool is not None else DuckDuckGoSearchRun()
|
15 |
+
|
16 |
+
# Initialize the LLM (Language Learning Model), if provided
|
17 |
+
self.llm = llm
|
18 |
+
|
19 |
+
# Define the agents
|
20 |
+
self.transcriber = Agent(
|
21 |
+
role='Expert Text Parser',
|
22 |
+
goal='Parse and rearrange unstructured OCR text into a standardized JSON dictionary',
|
23 |
+
backstory="""You work at a museum transcribing specimen labels.
|
24 |
+
Your expertise lies in precisely transcribing text and placing the text into the appropriate category.""",
|
25 |
+
verbose=True,
|
26 |
+
allow_delegation=False
|
27 |
+
# Optionally include llm=self.llm here if an LLM was provided
|
28 |
+
)
|
29 |
+
|
30 |
+
self.spell_check = Agent(
|
31 |
+
role='Spell Checker',
|
32 |
+
goal='Correct any typos in the JSON key values',
|
33 |
+
backstory="""Your job is to look at the JSON key values and use your knowledge to verify spelling. Your corrections should be incorporated into the JSON object that will be passed to the next employee, so return the spell-checked JSON dictionary or the previous JSON dictionary if no changes are required.""",
|
34 |
+
verbose=True,
|
35 |
+
allow_delegation=True,
|
36 |
+
# Optionally include llm=self.llm here if an LLM was provided
|
37 |
+
)
|
38 |
+
|
39 |
+
self.fact_check = Agent(
|
40 |
+
role='Fact Checker',
|
41 |
+
goal='Verify the accuracy of taxonomy and location names',
|
42 |
+
backstory="""Your job is to verify the plant taxonomy and geographic locations contained within the key values are accurate. You can use internet searches to check these fields. Your corrections should be incorporated into a new JSON object that will be passed to the next employee, so return the corrected JSON dictionary or the previous JSON dictionary if no changes are required.""",
|
43 |
+
verbose=True,
|
44 |
+
allow_delegation=True,
|
45 |
+
tools=[self.search_tool]
|
46 |
+
# Optionally include llm=self.llm here if an LLM was provided
|
47 |
+
)
|
48 |
+
|
49 |
+
self.validator = Agent(
|
50 |
+
role='Synthesis',
|
51 |
+
goal='Create a final museum JSON record',
|
52 |
+
backstory="""You must produce a final JSON dictionary only.""",
|
53 |
+
verbose=True,
|
54 |
+
allow_delegation=True,
|
55 |
+
)
|
56 |
+
|
57 |
+
# Define the tasks
|
58 |
+
self.task1 = Task(
|
59 |
+
description=f"Use your knowledge to reformat, transform, and rearrange the unstructured text to fit the following requirements:{JSON_rules}. For null values, use an empty string. This is the unformatted OCR text: {OCR}",
|
60 |
+
agent=self.transcriber
|
61 |
+
)
|
62 |
+
|
63 |
+
self.task2 = Task(
|
64 |
+
description=f"The original text is OCR text, which may contain minor typos. Your job is to check all of the key values and fix any minor typos or spelling mistakes. You should remove any extraneous characters that should not belong in an official museum record.",
|
65 |
+
agent=self.spell_check
|
66 |
+
)
|
67 |
+
|
68 |
+
self.task3 = Task(
|
69 |
+
description="""Use your knowledge or search the internet to verify the information contained within the JSON dictionary.
|
70 |
+
For taxonomy, use the information contained in these keys: order, family, scientificName, scientificNameAuthorship, genus, specificEpithet, infraspecificEpithet.
|
71 |
+
For geography, use the information contained in these keys: country, stateProvince, municipality, decimalLatitude, decimalLongitude.""",
|
72 |
+
agent=self.fact_check
|
73 |
+
)
|
74 |
+
|
75 |
+
self.task4 = Task(
|
76 |
+
description=f"Verify that the JSON dictionary is valid. If not, correct the error. Then print out the final JSON dictionary only without explanations.",
|
77 |
+
agent=self.validator
|
78 |
+
)
|
79 |
+
|
80 |
+
# Create the crew
|
81 |
+
# self.crew = Crew(
|
82 |
+
# agents=[self.transcriber, self.spell_check, self.fact_check, self.validator],
|
83 |
+
# tasks=[self.task1, self.task2, self.task3, self.task4],
|
84 |
+
# verbose=2, # You can set it to 1 or 2 for different logging levels
|
85 |
+
# manager_llm=ChatOpenAI(temperature=0, model="gpt-4-1106-preview"),
|
86 |
+
# process=Process.hierarchical,
|
87 |
+
# )
|
88 |
+
self.crew = Crew(
|
89 |
+
agents=[self.transcriber, self.validator],
|
90 |
+
tasks=[self.task1, self.task4],
|
91 |
+
manager_llm=ChatOpenAI(temperature=0, model="gpt-4-1106-preview"),
|
92 |
+
process=Process.sequential,
|
93 |
+
verbose=2 # You can set it to 1 or 2 for different logging levels
|
94 |
+
)
|
95 |
+
|
96 |
+
def execute_tasks(self):
|
97 |
+
# Kick off the process and return the result
|
98 |
+
result = self.crew.kickoff()
|
99 |
+
print("######################")
|
100 |
+
print(result)
|
101 |
+
return result
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
openai_api_key = ""
|
105 |
+
OCR = "HERBARIUM OF MARYGROVE COLLEGE Name Carex scoparia V. condensa Fernald Locality Interlaken , Ind . Date 7/20/25 No ... ! Gerould Wilhelm & Laura Rericha \" Interlaken , \" was the site for many years of St. Joseph Novitiate , run by the Brothers of the Holy Cross . The buildings were on the west shore of Silver Lake , about 2 miles NE of Rolling Prairie , LaPorte Co. Indiana , ca. 41.688 \u00b0 N , 86.601 \u00b0 W Collector : Sister M. Vincent de Paul McGivney February 1 , 2011 THE UNIVERS Examined for the Flora of the Chicago Region OF 1817 MICH ! Ciscoparia SMVdeP University of Michigan Herbarium 1386297 copyright reserved cm Collector wortet 2010"
|
106 |
+
JSON_rules = """This is the JSON template that includes instructions for each key
|
107 |
+
{'catalogNumber': barcode identifier, at least 6 digits, fewer than 30 digits.,
|
108 |
+
'order': full scientific name of the Order in which the taxon is classified. Order must be capitalized.,
|
109 |
+
'family': full scientific name of the Family in which the taxon is classified. Family must be capitalized.,
|
110 |
+
'scientificName': scientific name of the taxon including Genus, specific epithet, and any lower classifications.,
|
111 |
+
'scientificNameAuthorship': authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.,
|
112 |
+
'genus': taxonomic determination to Genus, Genus must be capitalized.,
|
113 |
+
'subgenus': name of the subgenus.,
|
114 |
+
'specificEpithet': The name of the first or species epithet of the scientificName. Only include the species epithet.,
|
115 |
+
'infraspecificEpithet': lowest or terminal infraspecific epithet of the scientificName.,
|
116 |
+
'identifiedBy': list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector., recordedBy list of names of people, doctors, professors, groups, or organizations.,
|
117 |
+
'recordNumber': identifier given to the specimen at the time it was recorded.,
|
118 |
+
'verbatimEventDate': The verbatim original representation of the date and time information for when the specimen was collected.,
|
119 |
+
'eventDate': collection date formatted as year-month-day YYYY-MM-DD., habitat habitat.,
|
120 |
+
'occurrenceRemarks': all descriptive text in the OCR rearranged into sensible sentences or sentence fragments.,
|
121 |
+
'country': country or major administrative unit.,
|
122 |
+
'stateProvince': state, province, canton, department, region, etc., county county, shire, department, parish etc.,
|
123 |
+
'municipality': city, municipality, etc., locality description of geographic information aiding in pinpointing the exact origin or location of the specimen.,
|
124 |
+
'degreeOfEstablishment': cultivated plants are intentionally grown by humans. Use either - unknown or cultivated.,
|
125 |
+
'decimalLatitude': latitude decimal coordinate.,
|
126 |
+
'decimalLongitude': longitude decimal coordinate., verbatimCoordinates verbatim location coordinates.,
|
127 |
+
'minimumElevationInMeters': minimum elevation or altitude in meters.,
|
128 |
+
'maximumElevationInMeters': maximum elevation or altitude in meters.}"""
|
129 |
+
ai_research_crew = AIResearchCrew(openai_api_key, OCR, JSON_rules)
|
130 |
+
result = ai_research_crew.execute_tasks()
|
vouchervision/LLM_local_cpu_MistralAI.py
CHANGED
@@ -56,8 +56,6 @@ class LocalCPUMistralHandler:
|
|
56 |
raise f"Unsupported GGUF model name"
|
57 |
|
58 |
# self.model_id = f"mistralai/{self.model_name}"
|
59 |
-
self.gpu_usage = {'max_load': 0, 'max_memory_usage': 0, 'monitoring': True}
|
60 |
-
|
61 |
self.starting_temp = float(self.STARTING_TEMP)
|
62 |
self.temp_increment = float(0.2)
|
63 |
self.adjust_temp = self.starting_temp
|
|
|
56 |
raise f"Unsupported GGUF model name"
|
57 |
|
58 |
# self.model_id = f"mistralai/{self.model_name}"
|
|
|
|
|
59 |
self.starting_temp = float(self.STARTING_TEMP)
|
60 |
self.temp_increment = float(0.2)
|
61 |
self.adjust_temp = self.starting_temp
|
vouchervision/OCR_CRAFT.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import Craft class
|
2 |
+
from craft_text_detector import read_image, load_craftnet_model, load_refinenet_model, get_prediction, export_detected_regions, export_extra_results, empty_cuda_cache
|
3 |
+
|
4 |
+
def main2():
|
5 |
+
# import craft functions
|
6 |
+
|
7 |
+
|
8 |
+
# set image path and export folder directory
|
9 |
+
# image = 'D:/Dropbox/SLTP/benchmark_datasets/SLTP_B50_MICH_Angiospermae2/img/MICH_7375774_Polygonaceae_Persicaria_.jpg' # can be filepath, PIL image or numpy array
|
10 |
+
# image = 'C:/Users/Will/Downloads/test_2024_02_07__14-59-52/Original_Images/SJRw 00891 - 01141__10001.jpg'
|
11 |
+
image = 'D:/Dropbox/VoucherVision/demo/demo_images/MICH_16205594_Poaceae_Jouvea_pilosa.jpg'
|
12 |
+
output_dir = 'D:/D_Desktop/test_out_CRAFT'
|
13 |
+
|
14 |
+
# read image
|
15 |
+
image = read_image(image)
|
16 |
+
|
17 |
+
# load models
|
18 |
+
refine_net = load_refinenet_model(cuda=True)
|
19 |
+
craft_net = load_craftnet_model(weight_path='D:/Dropbox/VoucherVision/vouchervision/craft/craft_mlt_25k.pth', cuda=True)
|
20 |
+
|
21 |
+
# perform prediction
|
22 |
+
prediction_result = get_prediction(
|
23 |
+
image=image,
|
24 |
+
craft_net=craft_net,
|
25 |
+
refine_net=refine_net,
|
26 |
+
text_threshold=0.4,
|
27 |
+
link_threshold=0.7,
|
28 |
+
low_text=0.4,
|
29 |
+
cuda=True,
|
30 |
+
long_size=1280
|
31 |
+
)
|
32 |
+
|
33 |
+
# export detected text regions
|
34 |
+
exported_file_paths = export_detected_regions(
|
35 |
+
image=image,
|
36 |
+
regions=prediction_result["boxes"],
|
37 |
+
output_dir=output_dir,
|
38 |
+
rectify=True
|
39 |
+
)
|
40 |
+
|
41 |
+
# export heatmap, detection points, box visualization
|
42 |
+
export_extra_results(
|
43 |
+
image=image,
|
44 |
+
regions=prediction_result["boxes"],
|
45 |
+
heatmaps=prediction_result["heatmaps"],
|
46 |
+
output_dir=output_dir
|
47 |
+
)
|
48 |
+
|
49 |
+
# unload models from gpu
|
50 |
+
empty_cuda_cache()
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
# main()
|
55 |
+
main2()
|
vouchervision/OCR_google_cloud_vision.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os, io, sys, inspect, statistics, json
|
2 |
from statistics import mean
|
3 |
# from google.cloud import vision, storage
|
4 |
from google.cloud import vision
|
@@ -8,10 +8,16 @@ import colorsys
|
|
8 |
from tqdm import tqdm
|
9 |
from google.oauth2 import service_account
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
'''
|
@@ -23,19 +29,31 @@ from google.oauth2 import service_account
|
|
23 |
archivePrefix={arXiv},
|
24 |
primaryClass={cs.CL}
|
25 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
'''
|
27 |
|
28 |
-
class
|
29 |
|
30 |
BBOX_COLOR = "black"
|
31 |
|
32 |
-
def __init__(self, is_hf, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
|
33 |
self.is_hf = is_hf
|
|
|
|
|
|
|
34 |
|
35 |
self.path = path
|
36 |
self.cfg = cfg
|
37 |
self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
|
38 |
self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
|
|
|
|
|
39 |
|
40 |
# Initialize TrOCR components
|
41 |
self.trOCR_model_version = trOCR_model_version
|
@@ -70,6 +88,9 @@ class OCRGoogle:
|
|
70 |
self.trOCR_confidences = None
|
71 |
self.trOCR_characters = None
|
72 |
self.set_client()
|
|
|
|
|
|
|
73 |
|
74 |
|
75 |
def set_client(self):
|
@@ -86,6 +107,131 @@ class OCRGoogle:
|
|
86 |
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
87 |
return credentials
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
|
91 |
CONFIDENCES = 0.80
|
@@ -93,33 +239,36 @@ class OCRGoogle:
|
|
93 |
|
94 |
self.OCR_JSON_to_file = {}
|
95 |
|
|
|
96 |
if not do_use_trOCR:
|
97 |
-
if self.OCR_option
|
98 |
self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
99 |
logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
|
100 |
-
|
|
|
101 |
|
102 |
-
if self.OCR_option
|
103 |
self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
104 |
logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
|
105 |
-
|
106 |
-
|
107 |
-
if self.OCR_option in ['both',]:
|
108 |
-
logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}")
|
109 |
-
return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}"
|
110 |
|
|
|
|
|
|
|
|
|
111 |
else:
|
112 |
logger.info(f'Supplementing with trOCR')
|
113 |
|
114 |
self.trOCR_texts = []
|
115 |
original_image = Image.open(self.path).convert("RGB")
|
116 |
|
117 |
-
if self.OCR_option
|
118 |
available_bounds = self.normal_bounds_word
|
119 |
-
elif self.OCR_option
|
120 |
-
available_bounds = self.hand_bounds_word
|
121 |
-
elif self.OCR_option in ['both',]:
|
122 |
available_bounds = self.hand_bounds_word
|
|
|
|
|
123 |
else:
|
124 |
raise
|
125 |
|
@@ -127,9 +276,13 @@ class OCRGoogle:
|
|
127 |
characters = []
|
128 |
height = []
|
129 |
confidences = []
|
|
|
|
|
130 |
for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
|
|
|
|
|
|
|
131 |
vertices = bound["vertices"]
|
132 |
-
|
133 |
|
134 |
left = min([v["x"] for v in vertices])
|
135 |
top = min([v["y"] for v in vertices])
|
@@ -177,24 +330,31 @@ class OCRGoogle:
|
|
177 |
self.trOCR_confidences = confidences
|
178 |
self.trOCR_characters = characters
|
179 |
|
180 |
-
if self.OCR_option
|
181 |
self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
182 |
self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
183 |
logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
184 |
-
|
185 |
-
|
|
|
186 |
self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
187 |
self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
188 |
logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
@staticmethod
|
200 |
def confidence_to_color(confidence):
|
@@ -220,7 +380,7 @@ class OCRGoogle:
|
|
220 |
if option == 'trOCR':
|
221 |
color = (0, 170, 255)
|
222 |
else:
|
223 |
-
color =
|
224 |
position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
|
225 |
draw.text(position, character, fill=color, font=font)
|
226 |
|
@@ -258,13 +418,13 @@ class OCRGoogle:
|
|
258 |
bound["vertices"][2]["x"], bound["vertices"][2]["y"],
|
259 |
bound["vertices"][3]["x"], bound["vertices"][3]["y"],
|
260 |
],
|
261 |
-
outline=
|
262 |
width=line_width_thin
|
263 |
)
|
264 |
|
265 |
# Draw a line segment at the bottom of each handwritten character
|
266 |
for bound, confidence in zip(bounds, confidences):
|
267 |
-
color =
|
268 |
# Use the bottom two vertices of the bounding box for the line
|
269 |
bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width_thick)
|
270 |
bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width_thick)
|
@@ -386,6 +546,7 @@ class OCRGoogle:
|
|
386 |
self.normal_height = height_flat
|
387 |
self.normal_confidences = confidences
|
388 |
self.normal_characters = characters
|
|
|
389 |
|
390 |
|
391 |
def detect_handwritten_ocr(self):
|
@@ -503,56 +664,112 @@ class OCRGoogle:
|
|
503 |
self.hand_height = height_flat
|
504 |
self.hand_confidences = confidences
|
505 |
self.hand_characters = characters
|
|
|
506 |
|
507 |
|
508 |
def process_image(self, do_create_OCR_helper_image, logger):
|
509 |
-
|
510 |
-
|
511 |
-
if self.OCR_option
|
512 |
-
self.
|
513 |
-
|
514 |
-
self.
|
515 |
-
self.
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
523 |
self.image = Image.open(self.path)
|
524 |
|
525 |
-
if
|
526 |
image_with_boxes_normal = self.draw_boxes('normal')
|
527 |
text_image_normal = self.render_text_on_black_image('normal')
|
528 |
self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_normal)
|
529 |
|
530 |
-
if
|
531 |
image_with_boxes_hand = self.draw_boxes('hand')
|
532 |
text_image_hand = self.render_text_on_black_image('hand')
|
533 |
self.merged_image_hand = self.merge_images(image_with_boxes_hand, text_image_hand)
|
534 |
|
535 |
if self.do_use_trOCR:
|
536 |
-
text_image_trOCR = self.render_text_on_black_image('trOCR')
|
|
|
|
|
|
|
|
|
|
|
|
|
537 |
|
538 |
### Merge final overlay image
|
539 |
### [original, normal bboxes, normal text]
|
540 |
-
if self.OCR_option
|
541 |
self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
|
542 |
### [original, hand bboxes, hand text]
|
543 |
-
elif self.OCR_option
|
544 |
self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
|
545 |
### [original, normal bboxes, normal text, hand bboxes, hand text]
|
546 |
else:
|
547 |
self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
|
548 |
|
549 |
if self.do_use_trOCR:
|
550 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
551 |
|
552 |
else:
|
553 |
self.merged_image_normal = None
|
554 |
self.merged_image_hand = None
|
555 |
self.overlay_image = Image.open(self.path)
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
|
557 |
|
558 |
'''
|
|
|
1 |
+
import os, io, sys, inspect, statistics, json, cv2
|
2 |
from statistics import mean
|
3 |
# from google.cloud import vision, storage
|
4 |
from google.cloud import vision
|
|
|
8 |
from tqdm import tqdm
|
9 |
from google.oauth2 import service_account
|
10 |
|
11 |
+
### LLaVA should only be installed if the user will actually use it.
|
12 |
+
### It requires the most recent pytorch/Python and can mess with older systems
|
13 |
+
try:
|
14 |
+
from craft_text_detector import read_image, load_craftnet_model, load_refinenet_model, get_prediction, export_detected_regions, export_extra_results, empty_cuda_cache
|
15 |
+
except:
|
16 |
+
pass
|
17 |
+
try:
|
18 |
+
from OCR_llava import OCRllava
|
19 |
+
except:
|
20 |
+
pass
|
21 |
|
22 |
|
23 |
'''
|
|
|
29 |
archivePrefix={arXiv},
|
30 |
primaryClass={cs.CL}
|
31 |
}
|
32 |
+
@inproceedings{baek2019character,
|
33 |
+
title={Character Region Awareness for Text Detection},
|
34 |
+
author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk},
|
35 |
+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
36 |
+
pages={9365--9374},
|
37 |
+
year={2019}
|
38 |
+
}
|
39 |
'''
|
40 |
|
41 |
+
class OCREngine:
|
42 |
|
43 |
BBOX_COLOR = "black"
|
44 |
|
45 |
+
def __init__(self, logger, json_report, dir_home, is_hf, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
|
46 |
self.is_hf = is_hf
|
47 |
+
self.logger = logger
|
48 |
+
|
49 |
+
self.json_report = json_report
|
50 |
|
51 |
self.path = path
|
52 |
self.cfg = cfg
|
53 |
self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
|
54 |
self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
|
55 |
+
self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
|
56 |
+
self.dir_home = dir_home
|
57 |
|
58 |
# Initialize TrOCR components
|
59 |
self.trOCR_model_version = trOCR_model_version
|
|
|
88 |
self.trOCR_confidences = None
|
89 |
self.trOCR_characters = None
|
90 |
self.set_client()
|
91 |
+
self.init_craft()
|
92 |
+
if 'LLaVA' in self.OCR_option:
|
93 |
+
self.init_llava()
|
94 |
|
95 |
|
96 |
def set_client(self):
|
|
|
107 |
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
108 |
return credentials
|
109 |
|
110 |
+
def init_craft(self):
|
111 |
+
if 'CRAFT' in self.OCR_option:
|
112 |
+
try:
|
113 |
+
self.refine_net = load_refinenet_model(cuda=True)
|
114 |
+
self.use_cuda = True
|
115 |
+
except:
|
116 |
+
self.refine_net = load_refinenet_model(cuda=False)
|
117 |
+
self.use_cuda = False
|
118 |
+
|
119 |
+
if self.use_cuda:
|
120 |
+
self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=True)
|
121 |
+
else:
|
122 |
+
self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
|
123 |
+
|
124 |
+
def init_llava(self):
|
125 |
+
|
126 |
+
self.llava_prompt = """I need you to transcribe all of the text in this image.
|
127 |
+
Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
|
128 |
+
|
129 |
+
self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
|
130 |
+
self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
|
131 |
+
|
132 |
+
self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
|
133 |
+
|
134 |
+
if self.model_quant == '4bit':
|
135 |
+
use_4bit = True
|
136 |
+
elif self.model_quant == 'full':
|
137 |
+
use_4bit = False
|
138 |
+
else:
|
139 |
+
self.logger.info(f"Provided model quantization invlid. Using 4bit.")
|
140 |
+
use_4bit = True
|
141 |
+
|
142 |
+
self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
|
143 |
+
|
144 |
+
|
145 |
+
def detect_text_craft(self):
|
146 |
+
# Perform prediction using CRAFT
|
147 |
+
image = read_image(self.path)
|
148 |
+
|
149 |
+
link_threshold = 0.85
|
150 |
+
text_threshold = 0.4
|
151 |
+
low_text = 0.4
|
152 |
+
|
153 |
+
if self.use_cuda:
|
154 |
+
self.prediction_result = get_prediction(
|
155 |
+
image=image,
|
156 |
+
craft_net=self.craft_net,
|
157 |
+
refine_net=self.refine_net,
|
158 |
+
text_threshold=text_threshold,
|
159 |
+
link_threshold=link_threshold,
|
160 |
+
low_text=low_text,
|
161 |
+
cuda=True,
|
162 |
+
long_size=1280
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
self.prediction_result = get_prediction(
|
166 |
+
image=image,
|
167 |
+
craft_net=self.craft_net,
|
168 |
+
refine_net=self.refine_net,
|
169 |
+
text_threshold=text_threshold,
|
170 |
+
link_threshold=link_threshold,
|
171 |
+
low_text=low_text,
|
172 |
+
cuda=False,
|
173 |
+
long_size=1280
|
174 |
+
)
|
175 |
+
|
176 |
+
# Initialize metadata structures
|
177 |
+
bounds = []
|
178 |
+
bounds_word = [] # CRAFT gives bounds for text regions, not individual words
|
179 |
+
text_to_box_mapping = []
|
180 |
+
bounds_flat = []
|
181 |
+
height_flat = []
|
182 |
+
confidences = [] # CRAFT does not provide confidences per character, so this might be uniformly set or estimated
|
183 |
+
characters = [] # Simulating as CRAFT doesn't provide character-level details
|
184 |
+
organized_text = ""
|
185 |
+
|
186 |
+
total_b = len(self.prediction_result["boxes"])
|
187 |
+
i=0
|
188 |
+
# Process each detected text region
|
189 |
+
for box in self.prediction_result["boxes"]:
|
190 |
+
i+=1
|
191 |
+
self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
|
192 |
+
|
193 |
+
vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
|
194 |
+
|
195 |
+
# Simulate a mapping for the whole detected region as a word
|
196 |
+
text_to_box_mapping.append({
|
197 |
+
"vertices": vertices,
|
198 |
+
"text": "detected_text" # Placeholder, as CRAFT does not provide the text content directly
|
199 |
+
})
|
200 |
+
|
201 |
+
# Assuming each box is a word for the sake of this example
|
202 |
+
bounds_word.append({"vertices": vertices})
|
203 |
+
|
204 |
+
# For simplicity, we're not dividing text regions into characters as CRAFT doesn't provide this
|
205 |
+
# Instead, we create a single large 'character' per detected region
|
206 |
+
bounds.append({"vertices": vertices})
|
207 |
+
|
208 |
+
# Simulate flat bounds and height for each detected region
|
209 |
+
x_positions = [vertex["x"] for vertex in vertices]
|
210 |
+
y_positions = [vertex["y"] for vertex in vertices]
|
211 |
+
min_x, max_x = min(x_positions), max(x_positions)
|
212 |
+
min_y, max_y = min(y_positions), max(y_positions)
|
213 |
+
avg_height = max_y - min_y
|
214 |
+
height_flat.append(avg_height)
|
215 |
+
|
216 |
+
# Assuming uniform confidence for all detected regions
|
217 |
+
confidences.append(1.0) # Placeholder confidence
|
218 |
+
|
219 |
+
# Adding dummy character for each box
|
220 |
+
characters.append("X") # Placeholder character
|
221 |
+
|
222 |
+
# Organize text as a single string (assuming each box is a word)
|
223 |
+
# organized_text += "detected_text " # Placeholder text
|
224 |
+
|
225 |
+
# Update class attributes with processed data
|
226 |
+
self.normal_bounds = bounds
|
227 |
+
self.normal_bounds_word = bounds_word
|
228 |
+
self.normal_text_to_box_mapping = text_to_box_mapping
|
229 |
+
self.normal_bounds_flat = bounds_flat # This would be similar to bounds if not processing characters individually
|
230 |
+
self.normal_height = height_flat
|
231 |
+
self.normal_confidences = confidences
|
232 |
+
self.normal_characters = characters
|
233 |
+
self.normal_organized_text = organized_text.strip()
|
234 |
+
|
235 |
|
236 |
def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
|
237 |
CONFIDENCES = 0.80
|
|
|
239 |
|
240 |
self.OCR_JSON_to_file = {}
|
241 |
|
242 |
+
ocr_parts = ''
|
243 |
if not do_use_trOCR:
|
244 |
+
if 'normal' in self.OCR_option:
|
245 |
self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
246 |
logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
|
247 |
+
# ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
|
248 |
+
ocr_parts = self.normal_organized_text
|
249 |
|
250 |
+
if 'hand' in self.OCR_option:
|
251 |
self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
252 |
logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
|
253 |
+
# ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
|
254 |
+
ocr_parts = self.hand_organized_text
|
|
|
|
|
|
|
255 |
|
256 |
+
# if self.OCR_option in ['both',]:
|
257 |
+
# logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}")
|
258 |
+
# return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}"
|
259 |
+
return ocr_parts
|
260 |
else:
|
261 |
logger.info(f'Supplementing with trOCR')
|
262 |
|
263 |
self.trOCR_texts = []
|
264 |
original_image = Image.open(self.path).convert("RGB")
|
265 |
|
266 |
+
if 'normal' in self.OCR_option or 'CRAFT' in self.OCR_option:
|
267 |
available_bounds = self.normal_bounds_word
|
268 |
+
elif 'hand' in self.OCR_option:
|
|
|
|
|
269 |
available_bounds = self.hand_bounds_word
|
270 |
+
# elif self.OCR_option in ['both',]:
|
271 |
+
# available_bounds = self.hand_bounds_word
|
272 |
else:
|
273 |
raise
|
274 |
|
|
|
276 |
characters = []
|
277 |
height = []
|
278 |
confidences = []
|
279 |
+
total_b = len(available_bounds)
|
280 |
+
i=0
|
281 |
for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
|
282 |
+
i+=1
|
283 |
+
self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
|
284 |
+
|
285 |
vertices = bound["vertices"]
|
|
|
286 |
|
287 |
left = min([v["x"] for v in vertices])
|
288 |
top = min([v["y"] for v in vertices])
|
|
|
330 |
self.trOCR_confidences = confidences
|
331 |
self.trOCR_characters = characters
|
332 |
|
333 |
+
if 'normal' in self.OCR_option:
|
334 |
self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
335 |
self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
336 |
logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
337 |
+
# ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
|
338 |
+
ocr_parts = self.trOCR_texts
|
339 |
+
if 'hand' in self.OCR_option:
|
340 |
self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
341 |
self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
342 |
logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
343 |
+
# ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
|
344 |
+
ocr_parts = self.trOCR_texts
|
345 |
+
# if self.OCR_option in ['both',]:
|
346 |
+
# self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
347 |
+
# self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
348 |
+
# self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
349 |
+
# logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
350 |
+
# ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
|
351 |
+
if 'CRAFT' in self.OCR_option:
|
352 |
+
# self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
353 |
+
self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
|
354 |
+
logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
|
355 |
+
# ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
|
356 |
+
ocr_parts = self.trOCR_texts
|
357 |
+
return ocr_parts
|
358 |
|
359 |
@staticmethod
|
360 |
def confidence_to_color(confidence):
|
|
|
380 |
if option == 'trOCR':
|
381 |
color = (0, 170, 255)
|
382 |
else:
|
383 |
+
color = OCREngine.confidence_to_color(confidence)
|
384 |
position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
|
385 |
draw.text(position, character, fill=color, font=font)
|
386 |
|
|
|
418 |
bound["vertices"][2]["x"], bound["vertices"][2]["y"],
|
419 |
bound["vertices"][3]["x"], bound["vertices"][3]["y"],
|
420 |
],
|
421 |
+
outline=OCREngine.BBOX_COLOR,
|
422 |
width=line_width_thin
|
423 |
)
|
424 |
|
425 |
# Draw a line segment at the bottom of each handwritten character
|
426 |
for bound, confidence in zip(bounds, confidences):
|
427 |
+
color = OCREngine.confidence_to_color(confidence)
|
428 |
# Use the bottom two vertices of the bounding box for the line
|
429 |
bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width_thick)
|
430 |
bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width_thick)
|
|
|
546 |
self.normal_height = height_flat
|
547 |
self.normal_confidences = confidences
|
548 |
self.normal_characters = characters
|
549 |
+
return self.normal_cleaned_text
|
550 |
|
551 |
|
552 |
def detect_handwritten_ocr(self):
|
|
|
664 |
self.hand_height = height_flat
|
665 |
self.hand_confidences = confidences
|
666 |
self.hand_characters = characters
|
667 |
+
return self.hand_cleaned_text
|
668 |
|
669 |
|
670 |
def process_image(self, do_create_OCR_helper_image, logger):
|
671 |
+
# Can stack options, so solitary if statements
|
672 |
+
self.OCR = 'OCR:\n'
|
673 |
+
if 'CRAFT' in self.OCR_option:
|
674 |
+
self.do_use_trOCR = True
|
675 |
+
self.detect_text_craft()
|
676 |
+
### Optionally add trOCR to the self.OCR for additional context
|
677 |
+
if self.double_OCR:
|
678 |
+
part_OCR = "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
679 |
+
self.OCR = self.OCR + part_OCR + part_OCR
|
680 |
+
else:
|
681 |
+
self.OCR = self.OCR + "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
682 |
+
logger.info(f"CRAFT trOCR:\n{self.OCR}")
|
683 |
+
|
684 |
+
if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
|
685 |
+
self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
|
686 |
+
|
687 |
+
image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.llava_prompt)
|
688 |
+
self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
|
689 |
+
|
690 |
+
try:
|
691 |
+
self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
|
692 |
+
except:
|
693 |
+
self.OCR_JSON_to_file = {}
|
694 |
+
self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
|
695 |
+
|
696 |
+
if self.double_OCR:
|
697 |
+
self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
|
698 |
+
else:
|
699 |
+
self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
|
700 |
+
logger.info(f"LLaVA OCR:\n{self.OCR}")
|
701 |
+
|
702 |
+
if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
|
703 |
+
if 'normal' in self.OCR_option:
|
704 |
+
self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
|
705 |
+
if 'hand' in self.OCR_option:
|
706 |
+
self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
|
707 |
+
# if self.OCR_option not in ['normal', 'hand', 'both']:
|
708 |
+
# self.OCR_option = 'both'
|
709 |
+
# self.detect_text()
|
710 |
+
# self.detect_handwritten_ocr()
|
711 |
+
|
712 |
+
### Optionally add trOCR to the self.OCR for additional context
|
713 |
+
if self.double_OCR:
|
714 |
+
part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
715 |
+
self.OCR = self.OCR + part_OCR + part_OCR
|
716 |
+
else:
|
717 |
+
self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
718 |
+
logger.info(f"OCR:\n{self.OCR}")
|
719 |
+
|
720 |
+
if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
|
721 |
self.image = Image.open(self.path)
|
722 |
|
723 |
+
if 'normal' in self.OCR_option:
|
724 |
image_with_boxes_normal = self.draw_boxes('normal')
|
725 |
text_image_normal = self.render_text_on_black_image('normal')
|
726 |
self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_normal)
|
727 |
|
728 |
+
if 'hand' in self.OCR_option:
|
729 |
image_with_boxes_hand = self.draw_boxes('hand')
|
730 |
text_image_hand = self.render_text_on_black_image('hand')
|
731 |
self.merged_image_hand = self.merge_images(image_with_boxes_hand, text_image_hand)
|
732 |
|
733 |
if self.do_use_trOCR:
|
734 |
+
text_image_trOCR = self.render_text_on_black_image('trOCR')
|
735 |
+
|
736 |
+
if 'CRAFT' in self.OCR_option:
|
737 |
+
image_with_boxes_normal = self.draw_boxes('normal')
|
738 |
+
self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)
|
739 |
+
|
740 |
+
|
741 |
|
742 |
### Merge final overlay image
|
743 |
### [original, normal bboxes, normal text]
|
744 |
+
if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
|
745 |
self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
|
746 |
### [original, hand bboxes, hand text]
|
747 |
+
elif 'hand' in self.OCR_option:
|
748 |
self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
|
749 |
### [original, normal bboxes, normal text, hand bboxes, hand text]
|
750 |
else:
|
751 |
self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
|
752 |
|
753 |
if self.do_use_trOCR:
|
754 |
+
if 'CRAFT' in self.OCR_option:
|
755 |
+
heat_map_text = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["text_score_heatmap"], cv2.COLOR_BGR2RGB))
|
756 |
+
heat_map_link = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["link_score_heatmap"], cv2.COLOR_BGR2RGB))
|
757 |
+
self.overlay_image = self.merge_images(self.overlay_image, heat_map_text)
|
758 |
+
self.overlay_image = self.merge_images(self.overlay_image, heat_map_link)
|
759 |
+
|
760 |
+
else:
|
761 |
+
self.overlay_image = self.merge_images(self.overlay_image, text_image_trOCR)
|
762 |
|
763 |
else:
|
764 |
self.merged_image_normal = None
|
765 |
self.merged_image_hand = None
|
766 |
self.overlay_image = Image.open(self.path)
|
767 |
+
|
768 |
+
try:
|
769 |
+
empty_cuda_cache()
|
770 |
+
except:
|
771 |
+
pass
|
772 |
+
|
773 |
|
774 |
|
775 |
'''
|
vouchervision/OCR_llava.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, re, logging
|
2 |
+
import requests
|
3 |
+
from PIL import Image
|
4 |
+
from io import BytesIO
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, TextStreamer
|
7 |
+
|
8 |
+
from langchain.prompts import PromptTemplate
|
9 |
+
from langchain_core.output_parsers import JsonOutputParser
|
10 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
11 |
+
|
12 |
+
from LLaVA.llava.model import LlavaLlamaForCausalLM
|
13 |
+
from LLaVA.llava.model.builder import load_pretrained_model
|
14 |
+
from LLaVA.llava.conversation import conv_templates, SeparatorStyle
|
15 |
+
from LLaVA.llava.utils import disable_torch_init
|
16 |
+
from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
|
17 |
+
from LLaVA.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
|
18 |
+
|
19 |
+
from utils_LLM import SystemLoadMonitor
|
20 |
+
|
21 |
+
'''
|
22 |
+
Performance expectations system:
|
23 |
+
GPUs:
|
24 |
+
2x RTX6000 Ada
|
25 |
+
CPU:
|
26 |
+
AMD Ryzen threadripper pro 5975wx 32-cores x64 threads
|
27 |
+
RAM:
|
28 |
+
512 GB
|
29 |
+
OS:
|
30 |
+
Ubuntu 22.04.3 LTS
|
31 |
+
|
32 |
+
LLaVA Models:
|
33 |
+
"liuhaotian/llava-v1.6-mistral-7b" --- Model is 20 GB in size --- Mistral-7B
|
34 |
+
--- Full
|
35 |
+
--- Inference time ~6 sec
|
36 |
+
--- VRAM ~18 GB
|
37 |
+
|
38 |
+
--- 8bit (don't use. author says there is a problem right now, 2024-02-08) anecdotally worse results too
|
39 |
+
--- Inference time ~37 sec
|
40 |
+
--- VRAM ~18 GB
|
41 |
+
|
42 |
+
--- 4bit
|
43 |
+
--- Inference time ~15 sec
|
44 |
+
--- VRAM ~9 GB
|
45 |
+
|
46 |
+
|
47 |
+
"liuhaotian/llava-v1.6-34b" --- Model is 100 GB in size --- Hermes-Yi-34B
|
48 |
+
--- Full
|
49 |
+
--- Inference time ~21 sec
|
50 |
+
--- VRAM ~70 GB
|
51 |
+
|
52 |
+
--- 8bit (don't use. author says there is a problem right now, 2024-02-08) anecdotally worse results too
|
53 |
+
--- Inference time ~52 sec
|
54 |
+
--- VRAM ~42 GB
|
55 |
+
|
56 |
+
--- 4bit
|
57 |
+
--- Inference time ~23 sec
|
58 |
+
--- VRAM ~25GB
|
59 |
+
|
60 |
+
|
61 |
+
"liuhaotian/llava-v1.6-vicuna-13b" --- Model is 30 GB in size --- Vicuna-13B
|
62 |
+
--- Full
|
63 |
+
--- Inference time ~8 sec
|
64 |
+
--- VRAM ~33 GB
|
65 |
+
|
66 |
+
--- 8bit (don't use. author says there is a problem right now, 2024-02-08) anecdotally worse results too, has lots of ALL CAPS and mistakes
|
67 |
+
--- Inference time ~32 sec
|
68 |
+
--- VRAM ~23 GB
|
69 |
+
|
70 |
+
--- 4bit
|
71 |
+
--- Inference time ~12 sec
|
72 |
+
--- VRAM ~15 GB
|
73 |
+
|
74 |
+
|
75 |
+
"liuhaotian/llava-v1.6-vicuna-7b" --- Model is 15 GB in size --- Vicuna-7B
|
76 |
+
--- Full
|
77 |
+
--- Inference time ~7 sec
|
78 |
+
--- VRAM ~20 GB
|
79 |
+
|
80 |
+
--- 8bit (don't use. author says there is a problem right now, 2024-02-08) anecdotally worse results too
|
81 |
+
--- Inference time ~27 sec
|
82 |
+
--- VRAM ~14 GB
|
83 |
+
|
84 |
+
--- 4bit
|
85 |
+
--- Inference time ~10 sec
|
86 |
+
--- VRAM ~10 GB
|
87 |
+
|
88 |
+
|
89 |
+
'''
|
90 |
+
|
91 |
+
# OCR_Llava = OCRLlava()
|
92 |
+
# image, caption = OCR_Llava.transcribe_image("path/to/image.jpg", "Describe this image.")
|
93 |
+
# print(caption)
|
94 |
+
|
95 |
+
# Define the desired data structure for the transcription.
|
96 |
+
class Transcription(BaseModel):
|
97 |
+
Transcription: str = Field(description="The transcription of all text in the image.")
|
98 |
+
|
99 |
+
class OCRllava:
|
100 |
+
def __init__(self, logger, model_path="liuhaotian/llava-v1.6-34b",load_in_4bit=False, load_in_8bit=False):
|
101 |
+
self.monitor = SystemLoadMonitor(logger)
|
102 |
+
|
103 |
+
# self.model_path = "liuhaotian/llava-v1.6-mistral-7b"
|
104 |
+
# self.model_path = "liuhaotian/llava-v1.6-34b"
|
105 |
+
# self.model_path = "liuhaotian/llava-v1.6-vicuna-13b"
|
106 |
+
|
107 |
+
self.model_path = model_path
|
108 |
+
|
109 |
+
# kwargs = {"device_map": "auto", "load_in_4bit": load_in_4bit, "quantization_config": BitsAndBytesConfig(
|
110 |
+
# load_in_4bit=load_in_4bit,
|
111 |
+
# bnb_4bit_compute_dtype=torch.float16,
|
112 |
+
# bnb_4bit_use_double_quant=load_in_4bit,
|
113 |
+
# bnb_4bit_quant_type='nf4'
|
114 |
+
# )}
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
if "llama-2" in self.model_path.lower(): # this is borrowed from def eval_model(args): in run_llava.py
|
119 |
+
self.conv_mode = "llava_llama_2"
|
120 |
+
elif "mistral" in self.model_path.lower():
|
121 |
+
self.conv_mode = "mistral_instruct"
|
122 |
+
elif "v1.6-34b" in self.model_path.lower():
|
123 |
+
self.conv_mode = "chatml_direct"
|
124 |
+
elif "v1" in self.model_path.lower():
|
125 |
+
self.conv_mode = "llava_v1"
|
126 |
+
elif "mpt" in self.model_path.lower():
|
127 |
+
self.conv_mode = "mpt"
|
128 |
+
else:
|
129 |
+
self.conv_mode = "llava_v0"
|
130 |
+
|
131 |
+
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(self.model_path, None,
|
132 |
+
model_name = get_model_name_from_path(self.model_path),
|
133 |
+
load_8bit=load_in_8bit, load_4bit=load_in_4bit)
|
134 |
+
|
135 |
+
# self.model = LlavaLlamaForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **kwargs)
|
136 |
+
# self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
|
137 |
+
# self.vision_tower = self.model.get_vision_tower()
|
138 |
+
# if not self.vision_tower.is_loaded:
|
139 |
+
# self.vision_tower.load_model()
|
140 |
+
# self.vision_tower.to(device='cuda')
|
141 |
+
# self.image_processor = self.vision_tower.image_processor
|
142 |
+
self.parser = JsonOutputParser(pydantic_object=Transcription)
|
143 |
+
|
144 |
+
def image_parser(self):
|
145 |
+
sep = ","
|
146 |
+
out = self.image_file.split(sep)
|
147 |
+
return out
|
148 |
+
|
149 |
+
def load_image(self, image_file):
|
150 |
+
if image_file.startswith("http") or image_file.startswith("https"):
|
151 |
+
response = requests.get(image_file)
|
152 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
153 |
+
else:
|
154 |
+
image = Image.open(image_file).convert("RGB")
|
155 |
+
return image
|
156 |
+
|
157 |
+
def load_images(self, image_files):
|
158 |
+
out = []
|
159 |
+
for image_file in image_files:
|
160 |
+
image = self.load_image(image_file)
|
161 |
+
out.append(image)
|
162 |
+
return out
|
163 |
+
|
164 |
+
def combine_json_values(self, data, separator=" "):
|
165 |
+
"""
|
166 |
+
Recursively traverses through a JSON-like dictionary or list,
|
167 |
+
combining all the values into a single string with a given separator.
|
168 |
+
|
169 |
+
:return: A single string containing all values from the input.
|
170 |
+
"""
|
171 |
+
# Base case for strings, directly return the string
|
172 |
+
if isinstance(data, str):
|
173 |
+
return data
|
174 |
+
|
175 |
+
# If the data is a dictionary, iterate through its values
|
176 |
+
elif isinstance(data, dict):
|
177 |
+
combined_string = separator.join(self.combine_json_values(v, separator) for v in data.values())
|
178 |
+
|
179 |
+
# If the data is a list, iterate through its elements
|
180 |
+
elif isinstance(data, list):
|
181 |
+
combined_string = separator.join(self.combine_json_values(item, separator) for item in data)
|
182 |
+
|
183 |
+
# For other data types (e.g., numbers), convert to string directly
|
184 |
+
else:
|
185 |
+
combined_string = str(data)
|
186 |
+
|
187 |
+
return combined_string
|
188 |
+
|
189 |
+
def transcribe_image(self, image_file, prompt, max_new_tokens=512, temperature=0.1, top_p=None, num_beams=1):
|
190 |
+
self.monitor.start_monitoring_usage()
|
191 |
+
|
192 |
+
self.image_file = image_file
|
193 |
+
if image_file.startswith('http') or image_file.startswith('https'):
|
194 |
+
response = requests.get(image_file)
|
195 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
196 |
+
else:
|
197 |
+
image = Image.open(image_file).convert('RGB')
|
198 |
+
disable_torch_init()
|
199 |
+
|
200 |
+
qs = prompt
|
201 |
+
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
202 |
+
if IMAGE_PLACEHOLDER in qs:
|
203 |
+
if self.model.config.mm_use_im_start_end:
|
204 |
+
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
|
205 |
+
else:
|
206 |
+
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
|
207 |
+
else:
|
208 |
+
if self.model.config.mm_use_im_start_end:
|
209 |
+
qs = image_token_se + "\n" + qs
|
210 |
+
else:
|
211 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
212 |
+
|
213 |
+
|
214 |
+
conv = conv_templates[self.conv_mode].copy()
|
215 |
+
conv.append_message(conv.roles[0], qs)
|
216 |
+
conv.append_message(conv.roles[1], None)
|
217 |
+
prompt = conv.get_prompt()
|
218 |
+
|
219 |
+
image_files = self.image_parser()
|
220 |
+
images = self.load_images(image_files)
|
221 |
+
image_sizes = [x.size for x in images]
|
222 |
+
images_tensor = process_images(
|
223 |
+
images,
|
224 |
+
self.image_processor,
|
225 |
+
self.model.config
|
226 |
+
).to(self.model.device, dtype=torch.float16)
|
227 |
+
|
228 |
+
input_ids = (
|
229 |
+
tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
230 |
+
.unsqueeze(0)
|
231 |
+
.cuda()
|
232 |
+
)
|
233 |
+
|
234 |
+
with torch.inference_mode():
|
235 |
+
output_ids = self.model.generate(
|
236 |
+
input_ids,
|
237 |
+
images=images_tensor,
|
238 |
+
image_sizes=image_sizes,
|
239 |
+
do_sample=True if temperature > 0 else False,
|
240 |
+
temperature=temperature,
|
241 |
+
# top_p=top_p,
|
242 |
+
num_beams=num_beams,
|
243 |
+
max_new_tokens=max_new_tokens,
|
244 |
+
use_cache=True,
|
245 |
+
)
|
246 |
+
|
247 |
+
direct_output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
248 |
+
|
249 |
+
# Parse the output to JSON format using the specified schema.
|
250 |
+
try:
|
251 |
+
json_output = self.parser.parse(direct_output)
|
252 |
+
except:
|
253 |
+
json_output = direct_output
|
254 |
+
|
255 |
+
try:
|
256 |
+
str_output = self.combine_json_values(json_output)
|
257 |
+
except:
|
258 |
+
str_output = direct_output
|
259 |
+
|
260 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
261 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
262 |
+
|
263 |
+
|
264 |
+
return image, json_output, direct_output, str_output, usage_report
|
265 |
+
|
266 |
+
|
267 |
+
PROMPT_OCR = """I need you to transcribe all of the text in this image. Place the transcribed text into a JSON dictionary with this form {"Transcription": "text"}"""
|
268 |
+
PROMPT_ALL = """1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
|
269 |
+
2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
|
270 |
+
3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
|
271 |
+
4. Duplicate dictionary fields are not allowed.
|
272 |
+
5. Ensure all JSON keys are in camel case.
|
273 |
+
6. Ensure new JSON field values follow sentence case capitalization.
|
274 |
+
7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
|
275 |
+
8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
|
276 |
+
9. Only return a JSON dictionary represented as a string. You should not explain your answer.
|
277 |
+
This section provides rules for formatting each JSON value organized by the JSON key.
|
278 |
+
{catalogNumber Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits., order The full scientific name of the order in which the taxon is classified. Order must be capitalized., family The full scientific name of the family in which the taxon is classified. Family must be capitalized., scientificName The scientific name of the taxon including genus, specific epithet, and any lower classifications., scientificNameAuthorship The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode., genus Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'., subgenus The full scientific name of the subgenus in which the taxon is classified. Values should include the genus to avoid homonym confusion., specificEpithet The name of the first or species epithet of the scientificName. Only include the species epithet., infraspecificEpithet The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation., identifiedBy A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector., recordedBy A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen. The primary collector or observer should be listed first., recordNumber An identifier given to the occurrence at the time it was recorded. Often serves as a link between field notes and an occurrence record, such as a specimen collector's number., verbatimEventDate The verbatim original representation of the date and time information for when the specimen was collected. Date of collection exactly as it appears on the label. Do not change the format or correct typos., eventDate Date the specimen was collected formatted as year-month-day, YYYY-MM_DD. If specific components of the date are unknown, they should be replaced with zeros. Examples \0000-00-00\ if the entire date is unknown, \YYYY-00-00\ if only the year is known, and \YYYY-MM-00\ if year and month are known but day is not., habitat A category or description of the habitat in which the specimen collection event occurred., occurrenceRemarks Text describing the specimen's geographic location. Text describing the appearance of the specimen. A statement about the presence or absence of a taxon at a the collection location. Text describing the significance of the specimen, such as a specific expedition or notable collection. Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics., country The name of the country or major administrative unit in which the specimen was originally collected., stateProvince The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected., county The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected., municipality The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected., locality Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or location of the specimen., degreeOfEstablishment Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant. Use either - unknown or cultivated., decimalLatitude Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format., decimalLongitude Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format., verbatimCoordinates Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types include [Lat, Long, UTM, TRS]., minimumElevationInMeters Minimum elevation or altitude in meters. Only if units are explicit then convert from feet (\ft\ or \ft.\\ or \feet\) to meters (\m\ or \m.\ or \meters\). Round to integer., maximumElevationInMeters Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet (\ft\ or \ft.\ or \feet\) to meters (\m\ or \m.\ or \meters\). Round to integer.}
|
279 |
+
Please populate the following JSON dictionary based on the rules and the unformatted OCR text
|
280 |
+
{
|
281 |
+
catalogNumber ,
|
282 |
+
order ,
|
283 |
+
family ,
|
284 |
+
scientificName ,
|
285 |
+
scientificNameAuthorship ,
|
286 |
+
genus ,
|
287 |
+
subgenus ,
|
288 |
+
specificEpithet ,
|
289 |
+
infraspecificEpithet ,
|
290 |
+
identifiedBy ,
|
291 |
+
recordedBy ,
|
292 |
+
recordNumber ,
|
293 |
+
verbatimEventDate ,
|
294 |
+
eventDate ,
|
295 |
+
habitat ,
|
296 |
+
occurrenceRemarks ,
|
297 |
+
country ,
|
298 |
+
stateProvince ,
|
299 |
+
county ,
|
300 |
+
municipality ,
|
301 |
+
locality ,
|
302 |
+
degreeOfEstablishment ,
|
303 |
+
decimalLatitude ,
|
304 |
+
decimalLongitude ,
|
305 |
+
verbatimCoordinates ,
|
306 |
+
minimumElevationInMeters ,
|
307 |
+
maximumElevationInMeters
|
308 |
+
}
|
309 |
+
"""
|
310 |
+
if __name__ == '__main__':
|
311 |
+
logger = logging.getLogger('LLaVA')
|
312 |
+
logger.setLevel(logging.DEBUG)
|
313 |
+
|
314 |
+
OCR_Llava = OCRllava(logger)
|
315 |
+
image, json_output, direct_output, str_output, usage_report = OCR_Llava.transcribe_image("/home/brlab/Dropbox/VoucherVision/demo/demo_images/MICH_16205594_Poaceae_Jouvea_pilosa.jpg",
|
316 |
+
PROMPT_OCR)
|
317 |
+
print('json_output')
|
318 |
+
print(json_output)
|
319 |
+
print('direct_output')
|
320 |
+
print(direct_output)
|
321 |
+
print('str_output')
|
322 |
+
print(str_output)
|
323 |
+
print('usage_report')
|
324 |
+
print(usage_report)
|
vouchervision/VoucherVision_Config_Builder.py
CHANGED
@@ -37,6 +37,10 @@ def build_VV_config(loaded_cfg=None):
|
|
37 |
|
38 |
do_use_trOCR = False
|
39 |
OCR_option = 'hand'
|
|
|
|
|
|
|
|
|
40 |
check_for_illegal_filenames = False
|
41 |
|
42 |
LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
|
@@ -47,6 +51,9 @@ def build_VV_config(loaded_cfg=None):
|
|
47 |
batch_size = 500
|
48 |
num_workers = 8
|
49 |
|
|
|
|
|
|
|
50 |
path_domain_knowledge = os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
|
51 |
embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]
|
52 |
|
@@ -58,8 +65,8 @@ def build_VV_config(loaded_cfg=None):
|
|
58 |
return assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
59 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
60 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
61 |
-
prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, save_cropped_annotations,
|
62 |
-
check_for_illegal_filenames, use_domain_knowledge=False)
|
63 |
else:
|
64 |
dir_home = os.path.dirname(os.path.dirname(__file__))
|
65 |
run_name = loaded_cfg['leafmachine']['project']['run_name']
|
@@ -74,6 +81,11 @@ def build_VV_config(loaded_cfg=None):
|
|
74 |
|
75 |
do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
|
76 |
OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
LLM_version_user = loaded_cfg['leafmachine']['LLM_version']
|
79 |
prompt_version = loaded_cfg['leafmachine']['project']['prompt_version']
|
@@ -88,19 +100,20 @@ def build_VV_config(loaded_cfg=None):
|
|
88 |
|
89 |
save_cropped_annotations = loaded_cfg['leafmachine']['cropped_components']['save_cropped_annotations']
|
90 |
check_for_illegal_filenames = loaded_cfg['leafmachine']['do']['check_for_illegal_filenames']
|
|
|
91 |
|
92 |
return assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
93 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
94 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
95 |
-
prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, save_cropped_annotations,
|
96 |
-
check_for_illegal_filenames, use_domain_knowledge=False)
|
97 |
|
98 |
|
99 |
def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
100 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
101 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
102 |
-
prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, save_cropped_annotations,
|
103 |
-
check_for_illegal_filenames, use_domain_knowledge=False):
|
104 |
|
105 |
|
106 |
# Initialize the base structure
|
@@ -112,6 +125,7 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
|
112 |
do_section = {
|
113 |
'check_for_illegal_filenames': check_for_illegal_filenames,
|
114 |
'check_for_corrupt_images_make_vertical': True,
|
|
|
115 |
}
|
116 |
|
117 |
print_section = {
|
@@ -144,6 +158,10 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
|
144 |
'delete_temps_keep_VVE': False,
|
145 |
'do_use_trOCR': do_use_trOCR,
|
146 |
'OCR_option': OCR_option,
|
|
|
|
|
|
|
|
|
147 |
}
|
148 |
|
149 |
modules_section = {
|
|
|
37 |
|
38 |
do_use_trOCR = False
|
39 |
OCR_option = 'hand'
|
40 |
+
OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
|
41 |
+
OCR_option_llava_bit = 'full' # full or 4bit
|
42 |
+
double_OCR = False
|
43 |
+
|
44 |
check_for_illegal_filenames = False
|
45 |
|
46 |
LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
|
|
|
51 |
batch_size = 500
|
52 |
num_workers = 8
|
53 |
|
54 |
+
skip_vertical = False
|
55 |
+
pdf_conversion_dpi = 100
|
56 |
+
|
57 |
path_domain_knowledge = os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
|
58 |
embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]
|
59 |
|
|
|
65 |
return assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
66 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
67 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
68 |
+
prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
|
69 |
+
check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
|
70 |
else:
|
71 |
dir_home = os.path.dirname(os.path.dirname(__file__))
|
72 |
run_name = loaded_cfg['leafmachine']['project']['run_name']
|
|
|
81 |
|
82 |
do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
|
83 |
OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
|
84 |
+
OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
|
85 |
+
OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
|
86 |
+
double_OCR = loaded_cfg['leafmachine']['project']['double_OCR']
|
87 |
+
|
88 |
+
pdf_conversion_dpi = loaded_cfg['leafmachine']['project']['pdf_conversion_dpi']
|
89 |
|
90 |
LLM_version_user = loaded_cfg['leafmachine']['LLM_version']
|
91 |
prompt_version = loaded_cfg['leafmachine']['project']['prompt_version']
|
|
|
100 |
|
101 |
save_cropped_annotations = loaded_cfg['leafmachine']['cropped_components']['save_cropped_annotations']
|
102 |
check_for_illegal_filenames = loaded_cfg['leafmachine']['do']['check_for_illegal_filenames']
|
103 |
+
skip_vertical = loaded_cfg['leafmachine']['do']['skip_vertical']
|
104 |
|
105 |
return assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
106 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
107 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
108 |
+
prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
|
109 |
+
check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
|
110 |
|
111 |
|
112 |
def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
113 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
114 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
115 |
+
prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
|
116 |
+
check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
|
117 |
|
118 |
|
119 |
# Initialize the base structure
|
|
|
125 |
do_section = {
|
126 |
'check_for_illegal_filenames': check_for_illegal_filenames,
|
127 |
'check_for_corrupt_images_make_vertical': True,
|
128 |
+
'skip_vertical': skip_vertical,
|
129 |
}
|
130 |
|
131 |
print_section = {
|
|
|
158 |
'delete_temps_keep_VVE': False,
|
159 |
'do_use_trOCR': do_use_trOCR,
|
160 |
'OCR_option': OCR_option,
|
161 |
+
'OCR_option_llava': OCR_option_llava,
|
162 |
+
'OCR_option_llava_bit': OCR_option_llava_bit,
|
163 |
+
'double_OCR': double_OCR,
|
164 |
+
'pdf_conversion_dpi': pdf_conversion_dpi,
|
165 |
}
|
166 |
|
167 |
modules_section = {
|
vouchervision/data_project.py
CHANGED
@@ -12,6 +12,19 @@ from vouchervision.download_from_GBIF_all_images_in_file import download_all_ima
|
|
12 |
from PIL import Image
|
13 |
from tqdm import tqdm
|
14 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
@dataclass
|
17 |
class Project_Info():
|
@@ -39,6 +52,7 @@ class Project_Info():
|
|
39 |
self.Dirs = Dirs
|
40 |
logger.name = 'Project Info'
|
41 |
logger.info("Gathering Images and Image Metadata")
|
|
|
42 |
|
43 |
self.batch_size = cfg['leafmachine']['project']['batch_size']
|
44 |
|
@@ -90,15 +104,28 @@ class Project_Info():
|
|
90 |
def remove_non_numbers(self, s):
|
91 |
return ''.join([char for char in s if char.isdigit()])
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
def copy_images_to_project_dir(self, dir_images, Dirs):
|
94 |
n_total = len(os.listdir(dir_images))
|
95 |
-
for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Copying images to working directory{bcolors.ENDC}',colour="white",position=0,total
|
96 |
-
# Copy og image to new dir
|
97 |
-
# Copied image will be used for all downstream applications
|
98 |
source = os.path.join(dir_images, file)
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def make_file_names_custom(self, dir_images, cfg, Dirs):
|
103 |
n_total = len(os.listdir(dir_images))
|
104 |
for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Creating Catalog Number from file name{bcolors.ENDC}',colour="green",position=0,total = n_total):
|
|
|
12 |
from PIL import Image
|
13 |
from tqdm import tqdm
|
14 |
from pathlib import Path
|
15 |
+
import fitz
|
16 |
+
|
17 |
+
def convert_pdf_to_jpg(source_pdf, destination_dir, dpi=100):
|
18 |
+
doc = fitz.open(source_pdf)
|
19 |
+
for page_num in range(len(doc)):
|
20 |
+
page = doc.load_page(page_num) # Load the current page
|
21 |
+
pix = page.get_pixmap(dpi=dpi) # Render page to an image
|
22 |
+
output_filename = f"{os.path.splitext(os.path.basename(source_pdf))[0]}__{10000 + page_num + 1}.jpg"
|
23 |
+
output_filepath = os.path.join(destination_dir, output_filename)
|
24 |
+
pix.save(output_filepath) # Save the image
|
25 |
+
length_doc = len(doc)
|
26 |
+
doc.close()
|
27 |
+
return length_doc
|
28 |
|
29 |
@dataclass
|
30 |
class Project_Info():
|
|
|
52 |
self.Dirs = Dirs
|
53 |
logger.name = 'Project Info'
|
54 |
logger.info("Gathering Images and Image Metadata")
|
55 |
+
self.logger = logger
|
56 |
|
57 |
self.batch_size = cfg['leafmachine']['project']['batch_size']
|
58 |
|
|
|
104 |
def remove_non_numbers(self, s):
|
105 |
return ''.join([char for char in s if char.isdigit()])
|
106 |
|
107 |
+
# def copy_images_to_project_dir(self, dir_images, Dirs):
|
108 |
+
# n_total = len(os.listdir(dir_images))
|
109 |
+
# for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Copying images to working directory{bcolors.ENDC}',colour="white",position=0,total = n_total):
|
110 |
+
# # Copy og image to new dir
|
111 |
+
# # Copied image will be used for all downstream applications
|
112 |
+
# source = os.path.join(dir_images, file)
|
113 |
+
# destination = os.path.join(Dirs.save_original, file)
|
114 |
+
# shutil.copy(source, destination)
|
115 |
def copy_images_to_project_dir(self, dir_images, Dirs):
|
116 |
n_total = len(os.listdir(dir_images))
|
117 |
+
for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Copying images to working directory{bcolors.ENDC}', colour="white", position=0, total=n_total):
|
|
|
|
|
118 |
source = os.path.join(dir_images, file)
|
119 |
+
# Check if file is a PDF
|
120 |
+
if file.lower().endswith('.pdf'):
|
121 |
+
# Convert PDF pages to JPG images
|
122 |
+
n_pages = convert_pdf_to_jpg(source, Dirs.save_original)
|
123 |
+
self.logger.info(f"Converted {n_pages} pages to JPG from PDF: {file}")
|
124 |
+
else:
|
125 |
+
# Copy non-PDF files directly
|
126 |
+
destination = os.path.join(Dirs.save_original, file)
|
127 |
+
shutil.copy(source, destination)
|
128 |
+
|
129 |
def make_file_names_custom(self, dir_images, cfg, Dirs):
|
130 |
n_total = len(os.listdir(dir_images))
|
131 |
for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Creating Catalog Number from file name{bcolors.ENDC}',colour="green",position=0,total = n_total):
|
vouchervision/general_utils.py
CHANGED
@@ -437,6 +437,7 @@ def split_into_batches(Project, logger, cfg):
|
|
437 |
return Project, n_batches, m
|
438 |
|
439 |
def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
|
|
|
440 |
if cfg['leafmachine']['do']['check_for_corrupt_images_make_vertical']:
|
441 |
n_rotate = 0
|
442 |
n_corrupt = 0
|
@@ -445,10 +446,11 @@ def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
|
|
445 |
if image_name_jpg.endswith((".jpg",".JPG",".jpeg",".JPEG")):
|
446 |
try:
|
447 |
image = cv2.imread(os.path.join(dir_images_unprocessed, image_name_jpg))
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
|
|
452 |
cv2.imwrite(os.path.join(dir_images_unprocessed,image_name_jpg), image)
|
453 |
except:
|
454 |
n_corrupt +=1
|
@@ -457,10 +459,11 @@ def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
|
|
457 |
elif image_name_jpg.endswith((".tiff",".tif",".png",".PNG",".TIFF",".TIF",".jp2",".JP2",".bmp",".BMP",".dib",".DIB")):
|
458 |
try:
|
459 |
image = cv2.imread(os.path.join(dir_images_unprocessed, image_name_jpg))
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
|
|
464 |
image_name_jpg = '.'.join([image_name_jpg.split('.')[0], 'jpg'])
|
465 |
cv2.imwrite(os.path.join(dir_images_unprocessed,image_name_jpg), image)
|
466 |
except:
|
|
|
437 |
return Project, n_batches, m
|
438 |
|
439 |
def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
|
440 |
+
skip_vertical = cfg['leafmachine']['do']['skip_vertical']
|
441 |
if cfg['leafmachine']['do']['check_for_corrupt_images_make_vertical']:
|
442 |
n_rotate = 0
|
443 |
n_corrupt = 0
|
|
|
446 |
if image_name_jpg.endswith((".jpg",".JPG",".jpeg",".JPEG")):
|
447 |
try:
|
448 |
image = cv2.imread(os.path.join(dir_images_unprocessed, image_name_jpg))
|
449 |
+
if not skip_vertical:
|
450 |
+
h, w, img_c = image.shape
|
451 |
+
image, img_h, img_w, did_rotate = make_image_vertical(image, h, w, do_rotate_180=False)
|
452 |
+
if did_rotate:
|
453 |
+
n_rotate += 1
|
454 |
cv2.imwrite(os.path.join(dir_images_unprocessed,image_name_jpg), image)
|
455 |
except:
|
456 |
n_corrupt +=1
|
|
|
459 |
elif image_name_jpg.endswith((".tiff",".tif",".png",".PNG",".TIFF",".TIF",".jp2",".JP2",".bmp",".BMP",".dib",".DIB")):
|
460 |
try:
|
461 |
image = cv2.imread(os.path.join(dir_images_unprocessed, image_name_jpg))
|
462 |
+
if not skip_vertical:
|
463 |
+
h, w, img_c = image.shape
|
464 |
+
image, img_h, img_w, did_rotate = make_image_vertical(image, h, w, do_rotate_180=False)
|
465 |
+
if did_rotate:
|
466 |
+
n_rotate += 1
|
467 |
image_name_jpg = '.'.join([image_name_jpg.split('.')[0], 'jpg'])
|
468 |
cv2.imwrite(os.path.join(dir_images_unprocessed,image_name_jpg), image)
|
469 |
except:
|
vouchervision/llava_test.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from LLaVA.llava.model.builder import load_pretrained_model
|
2 |
+
from LLaVA.llava.mm_utils import get_model_name_from_path
|
3 |
+
from LLaVA.llava.eval.run_llava import eval_model
|
4 |
+
|
5 |
+
# model_path = "liuhaotian/llava-v1.5-7b"
|
6 |
+
|
7 |
+
# tokenizer, model, image_processor, context_len = load_pretrained_model(
|
8 |
+
# model_path=model_path,
|
9 |
+
# model_base=None,
|
10 |
+
# model_name=get_model_name_from_path(model_path)
|
11 |
+
# )
|
12 |
+
|
13 |
+
# model_path = "liuhaotian/llava-v1.5-7b"
|
14 |
+
# model_path = "liuhaotian/llava-v1.6-mistral-7b"
|
15 |
+
model_path = "liuhaotian/llava-v1.6-34b"
|
16 |
+
prompt = """I need you to transcribe all of the text in this image. Place the transcribed text into a JSON dictionary with this form {"Transcription": "text"}"""
|
17 |
+
# image_file = "https://llava-vl.github.io/static/images/view.jpg"
|
18 |
+
image_file = "/home/brlab/Dropbox/VoucherVision/demo/demo_images/MICH_16205594_Poaceae_Jouvea_pilosa.jpg"
|
19 |
+
args = type('Args', (), {
|
20 |
+
"model_path": model_path,
|
21 |
+
"model_base": None,
|
22 |
+
"model_name": get_model_name_from_path(model_path),
|
23 |
+
"query": prompt,
|
24 |
+
"conv_mode": None,
|
25 |
+
"image_file": image_file,
|
26 |
+
"sep": ",",
|
27 |
+
"temperature": 0,
|
28 |
+
"top_p": None,
|
29 |
+
"num_beams": 1,
|
30 |
+
"max_new_tokens": 512,
|
31 |
+
# "load_8_bit": True,
|
32 |
+
})()
|
33 |
+
|
34 |
+
eval_model(args)
|
vouchervision/utils_LLM.py
CHANGED
@@ -49,7 +49,7 @@ class SystemLoadMonitor():
|
|
49 |
def __init__(self, logger) -> None:
|
50 |
self.monitoring_thread = None
|
51 |
self.logger = logger
|
52 |
-
self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'monitoring': True}
|
53 |
self.start_time = None
|
54 |
self.tool_start_time = None
|
55 |
self.has_GPU = torch.cuda.is_available()
|
@@ -71,11 +71,17 @@ class SystemLoadMonitor():
|
|
71 |
# GPU monitoring
|
72 |
if self.has_GPU:
|
73 |
GPUs = GPUtil.getGPUs()
|
|
|
|
|
|
|
74 |
for gpu in GPUs:
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
79 |
|
80 |
# RAM monitoring
|
81 |
ram_usage = psutil.virtual_memory().used / (1024.0 ** 3) # Get RAM usage in GB
|
@@ -94,46 +100,91 @@ class SystemLoadMonitor():
|
|
94 |
return datetime_iso
|
95 |
|
96 |
def stop_monitoring_report_usage(self):
|
97 |
-
report = {}
|
98 |
-
|
99 |
self.gpu_usage['monitoring'] = False
|
100 |
self.monitoring_thread.join()
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
112 |
}
|
|
|
113 |
self.logger.info(f"Inference Time: {round(self.inference_time,2)} seconds")
|
114 |
self.logger.info(f"Tool Time: {round(tool_time,2)} seconds")
|
115 |
-
|
116 |
self.logger.info(f"Max CPU Usage: {round(self.gpu_usage['max_cpu_usage'],2)}%")
|
117 |
-
self.logger.info(f"Max RAM Usage: {round(self.gpu_usage['max_ram_usage'],2)}GB")
|
118 |
-
|
119 |
if self.has_GPU:
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
self.logger.info(f"Max GPU Memory Usage: {round(self.gpu_usage['max_vram_usage'],2)}GB")
|
125 |
else:
|
126 |
-
report.update({'max_gpu_load':
|
127 |
-
report.update({'max_gpu_vram_gb':
|
128 |
|
129 |
return report
|
130 |
|
131 |
-
|
132 |
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
|
139 |
|
|
|
49 |
def __init__(self, logger) -> None:
|
50 |
self.monitoring_thread = None
|
51 |
self.logger = logger
|
52 |
+
self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'n_gpus': 0, 'monitoring': True}
|
53 |
self.start_time = None
|
54 |
self.tool_start_time = None
|
55 |
self.has_GPU = torch.cuda.is_available()
|
|
|
71 |
# GPU monitoring
|
72 |
if self.has_GPU:
|
73 |
GPUs = GPUtil.getGPUs()
|
74 |
+
self.gpu_usage['n_gpus'] = len(GPUs) # Count the number of GPUs
|
75 |
+
total_load = 0
|
76 |
+
total_memory_usage_gb = 0
|
77 |
for gpu in GPUs:
|
78 |
+
total_load += gpu.load
|
79 |
+
total_memory_usage_gb += gpu.memoryUsed / 1024.0
|
80 |
+
|
81 |
+
if self.gpu_usage['n_gpus'] > 0: # Avoid division by zero
|
82 |
+
# Calculate the average load and memory usage across all GPUs
|
83 |
+
self.gpu_usage['max_load'] = max(self.gpu_usage['max_load'], total_load / self.gpu_usage['n_gpus'])
|
84 |
+
self.gpu_usage['max_vram_usage'] = max(self.gpu_usage['max_vram_usage'], total_memory_usage_gb)
|
85 |
|
86 |
# RAM monitoring
|
87 |
ram_usage = psutil.virtual_memory().used / (1024.0 ** 3) # Get RAM usage in GB
|
|
|
100 |
return datetime_iso
|
101 |
|
102 |
def stop_monitoring_report_usage(self):
|
|
|
|
|
103 |
self.gpu_usage['monitoring'] = False
|
104 |
self.monitoring_thread.join()
|
105 |
+
tool_time = time.time() - self.tool_start_time if self.tool_start_time else 0
|
106 |
+
|
107 |
+
num_gpus, gpu_dict, total_vram_gb, capability_score = check_system_gpus()
|
108 |
+
|
109 |
+
report = {
|
110 |
+
'inference_time_s': str(round(self.inference_time, 2)),
|
111 |
+
'tool_time_s': str(round(tool_time, 2)),
|
112 |
+
'max_cpu': str(round(self.gpu_usage['max_cpu_usage'], 2)),
|
113 |
+
'max_ram_gb': str(round(self.gpu_usage['max_ram_usage'], 2)),
|
114 |
+
'current_time': self.get_current_datetime(),
|
115 |
+
'n_gpus': self.gpu_usage['n_gpus'],
|
116 |
+
'total_gpu_vram_gb':total_vram_gb,
|
117 |
+
'capability_score':capability_score,
|
118 |
+
|
119 |
}
|
120 |
+
|
121 |
self.logger.info(f"Inference Time: {round(self.inference_time,2)} seconds")
|
122 |
self.logger.info(f"Tool Time: {round(tool_time,2)} seconds")
|
|
|
123 |
self.logger.info(f"Max CPU Usage: {round(self.gpu_usage['max_cpu_usage'],2)}%")
|
124 |
+
self.logger.info(f"Max RAM Usage: {round(self.gpu_usage['max_ram_usage'],2)}GB")
|
|
|
125 |
if self.has_GPU:
|
126 |
+
report.update({'max_gpu_load': str(round(self.gpu_usage['max_load'] * 100, 2))})
|
127 |
+
report.update({'max_gpu_vram_gb': str(round(self.gpu_usage['max_vram_usage'], 2))})
|
128 |
+
self.logger.info(f"Max GPU Load: {round(self.gpu_usage['max_load'] * 100, 2)}%")
|
129 |
+
self.logger.info(f"Max GPU Memory Usage: {round(self.gpu_usage['max_vram_usage'], 2)}GB")
|
|
|
130 |
else:
|
131 |
+
report.update({'max_gpu_load': '0'})
|
132 |
+
report.update({'max_gpu_vram_gb': '0'})
|
133 |
|
134 |
return report
|
135 |
|
|
|
136 |
|
137 |
|
138 |
+
def check_system_gpus():
|
139 |
+
print(f"Torch CUDA: {torch.cuda.is_available()}")
|
140 |
+
# if not torch.cuda.is_available():
|
141 |
+
# return 0, {}, 0, "no_gpu"
|
142 |
+
|
143 |
+
GPUs = GPUtil.getGPUs()
|
144 |
+
num_gpus = len(GPUs)
|
145 |
+
gpu_dict = {}
|
146 |
+
total_vram = 0
|
147 |
+
|
148 |
+
for i, gpu in enumerate(GPUs):
|
149 |
+
gpu_vram = gpu.memoryTotal # VRAM in MB
|
150 |
+
gpu_dict[f"GPU_{i}"] = f"{gpu_vram / 1024} GB" # Convert to GB
|
151 |
+
total_vram += gpu_vram
|
152 |
+
|
153 |
+
total_vram_gb = total_vram / 1024 # Convert total VRAM to GB
|
154 |
+
|
155 |
+
capability_score_map = {
|
156 |
+
"no_gpu": 0,
|
157 |
+
"class_8GB": 10,
|
158 |
+
"class_12GB": 14,
|
159 |
+
"class_16GB": 18,
|
160 |
+
"class_24GB": 26,
|
161 |
+
"class_48GB": 50,
|
162 |
+
"class_96GB": 100,
|
163 |
+
"class_96GBplus": float('inf'), # Use infinity to represent any value greater than 96GB
|
164 |
+
}
|
165 |
+
|
166 |
+
# Determine the capability score based on the total VRAM
|
167 |
+
capability_score = "no_gpu"
|
168 |
+
for score, vram in capability_score_map.items():
|
169 |
+
if total_vram_gb <= vram:
|
170 |
+
capability_score = score
|
171 |
+
break
|
172 |
+
else:
|
173 |
+
capability_score = "class_max"
|
174 |
+
|
175 |
+
return num_gpus, gpu_dict, total_vram_gb, capability_score
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
|
180 |
|
181 |
|
182 |
+
if __name__ == '__main__':
|
183 |
+
num_gpus, gpu_dict, total_vram_gb, capability_score = check_system_gpus()
|
184 |
+
print(f"Number of GPUs: {num_gpus}")
|
185 |
+
print(f"GPU Details: {gpu_dict}")
|
186 |
+
print(f"Total VRAM: {total_vram_gb} GB")
|
187 |
+
print(f"Capability Score: {capability_score}")
|
188 |
|
189 |
|
190 |
|
vouchervision/utils_LLM_JSON_validation.py
CHANGED
@@ -11,7 +11,7 @@ def validate_and_align_JSON_keys_with_template(data, JSON_dict_structure):
|
|
11 |
if value is None:
|
12 |
data[key] = ''
|
13 |
elif isinstance(value, str):
|
14 |
-
if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null',
|
15 |
'not provided in the text', 'not found in the text',
|
16 |
'not in the text', 'not provided', 'not found',
|
17 |
'not provided in the ocr', 'not found in the ocr',
|
|
|
11 |
if value is None:
|
12 |
data[key] = ''
|
13 |
elif isinstance(value, str):
|
14 |
+
if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null', 'unspecified',
|
15 |
'not provided in the text', 'not found in the text',
|
16 |
'not in the text', 'not provided', 'not found',
|
17 |
'not provided in the ocr', 'not found in the ocr',
|
vouchervision/utils_VoucherVision.py
CHANGED
@@ -5,10 +5,8 @@ from openpyxl import Workbook, load_workbook
|
|
5 |
import vertexai
|
6 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
7 |
from langchain_openai import AzureChatOpenAI
|
8 |
-
from OCR_google_cloud_vision import OCRGoogle
|
9 |
-
# import google.generativeai as genai
|
10 |
from google.oauth2 import service_account
|
11 |
-
|
12 |
|
13 |
from vouchervision.LLM_OpenAI import OpenAIHandler
|
14 |
from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
|
@@ -20,6 +18,7 @@ from vouchervision.utils_LLM import remove_colons_and_double_apostrophes
|
|
20 |
from vouchervision.prompt_catalog import PromptCatalog
|
21 |
from vouchervision.model_maps import ModelMaps
|
22 |
from vouchervision.general_utils import get_cfg_from_full_path
|
|
|
23 |
|
24 |
'''
|
25 |
* For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
|
@@ -44,9 +43,11 @@ class VoucherVision():
|
|
44 |
self.prompt_version = None
|
45 |
self.is_hf = is_hf
|
46 |
|
47 |
-
|
48 |
-
self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
49 |
-
# self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask"
|
|
|
|
|
50 |
self.trOCR_processor = None
|
51 |
self.trOCR_model = None
|
52 |
|
@@ -77,12 +78,12 @@ class VoucherVision():
|
|
77 |
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
78 |
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
79 |
|
80 |
-
self.usage_headers = ["current_time", "inference_time_s", "tool_time_s","max_cpu", "max_ram_gb", "max_gpu_load", "max_gpu_vram_gb",]
|
81 |
|
82 |
self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
|
83 |
self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
84 |
|
85 |
-
self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
86 |
# "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
87 |
|
88 |
# "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
@@ -117,8 +118,8 @@ class VoucherVision():
|
|
117 |
lgr = logging.getLogger('transformers')
|
118 |
lgr.setLevel(logging.ERROR)
|
119 |
|
120 |
-
self.trOCR_processor = TrOCRProcessor.from_pretrained(
|
121 |
-
self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(self.trOCR_model_version)
|
122 |
|
123 |
# Check for GPU availability
|
124 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -297,7 +298,7 @@ class VoucherVision():
|
|
297 |
break
|
298 |
|
299 |
|
300 |
-
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
301 |
|
302 |
|
303 |
wb = openpyxl.load_workbook(path_transcription)
|
@@ -364,6 +365,8 @@ class VoucherVision():
|
|
364 |
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
365 |
elif header.value == "prompt":
|
366 |
sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
|
|
|
|
|
367 |
|
368 |
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
369 |
elif header.value in self.wfo_headers_no_lists:
|
@@ -613,12 +616,12 @@ class VoucherVision():
|
|
613 |
##################################################################################################################################
|
614 |
################################################## OCR ##################################################################
|
615 |
##################################################################################################################################
|
616 |
-
def perform_OCR_and_save_results(self, image_index, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
|
617 |
self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')
|
618 |
# self.OCR - None
|
619 |
|
620 |
### Process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
|
621 |
-
ocr_google =
|
622 |
ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
|
623 |
self.OCR = ocr_google.OCR
|
624 |
|
@@ -682,7 +685,7 @@ class VoucherVision():
|
|
682 |
|
683 |
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
684 |
json_report.set_text(text_main='Starting OCR')
|
685 |
-
self.perform_OCR_and_save_results(i, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
|
686 |
json_report.set_text(text_main='Finished OCR')
|
687 |
|
688 |
if not self.OCR:
|
@@ -797,10 +800,10 @@ class VoucherVision():
|
|
797 |
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
798 |
# Saving the JSON and XLSX files with the response and updating the final JSON response
|
799 |
if response_candidate is not None:
|
800 |
-
final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
801 |
return final_JSON_response_updated, WFO_record, GEO_record
|
802 |
else:
|
803 |
-
final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
804 |
return final_JSON_response_updated, WFO_record, GEO_record
|
805 |
|
806 |
|
@@ -836,7 +839,7 @@ class VoucherVision():
|
|
836 |
return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
|
837 |
|
838 |
|
839 |
-
def save_json_and_xlsx(self, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
|
840 |
if response is None:
|
841 |
response = self.JSON_dict_structure
|
842 |
# Insert 'filename' as the first key
|
@@ -845,14 +848,14 @@ class VoucherVision():
|
|
845 |
|
846 |
# Then add the null info to the spreadsheet
|
847 |
response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
|
848 |
-
self.add_data_to_excel_from_response(self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
|
849 |
|
850 |
### Set completed JSON
|
851 |
else:
|
852 |
response = self.clean_catalog_number(response, filename_without_extension)
|
853 |
self.write_json_to_file(txt_file_path, response)
|
854 |
# add to the xlsx file
|
855 |
-
self.add_data_to_excel_from_response(self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
856 |
return response
|
857 |
|
858 |
|
|
|
5 |
import vertexai
|
6 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
7 |
from langchain_openai import AzureChatOpenAI
|
|
|
|
|
8 |
from google.oauth2 import service_account
|
9 |
+
from transformers import AutoTokenizer, AutoModel
|
10 |
|
11 |
from vouchervision.LLM_OpenAI import OpenAIHandler
|
12 |
from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
|
|
|
18 |
from vouchervision.prompt_catalog import PromptCatalog
|
19 |
from vouchervision.model_maps import ModelMaps
|
20 |
from vouchervision.general_utils import get_cfg_from_full_path
|
21 |
+
from vouchervision.OCR_google_cloud_vision import OCREngine
|
22 |
|
23 |
'''
|
24 |
* For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
|
|
|
43 |
self.prompt_version = None
|
44 |
self.is_hf = is_hf
|
45 |
|
46 |
+
self.trOCR_model_version = "microsoft/trocr-large-handwritten"
|
47 |
+
# self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
48 |
+
# self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
|
49 |
+
# self.trOCR_model_version = "dh-unibe/trocr-kurrent" # NOPE
|
50 |
+
# self.trOCR_model_version = "DunnBC22/trocr-base-handwritten-OCR-handwriting_recognition_v2" # NOPE
|
51 |
self.trOCR_processor = None
|
52 |
self.trOCR_model = None
|
53 |
|
|
|
78 |
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
79 |
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
80 |
|
81 |
+
self.usage_headers = ["current_time", "inference_time_s", "tool_time_s","max_cpu", "max_ram_gb", "n_gpus", "max_gpu_load", "max_gpu_vram_gb","total_gpu_vram_gb","capability_score",]
|
82 |
|
83 |
self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
|
84 |
self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
85 |
|
86 |
+
self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
87 |
# "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
88 |
|
89 |
# "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
|
|
118 |
lgr = logging.getLogger('transformers')
|
119 |
lgr.setLevel(logging.ERROR)
|
120 |
|
121 |
+
self.trOCR_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") # usually just the "microsoft/trocr-base-handwritten"
|
122 |
+
self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(self.trOCR_model_version) # This matches the model
|
123 |
|
124 |
# Check for GPU availability
|
125 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
298 |
break
|
299 |
|
300 |
|
301 |
+
def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
302 |
|
303 |
|
304 |
wb = openpyxl.load_workbook(path_transcription)
|
|
|
365 |
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
366 |
elif header.value == "prompt":
|
367 |
sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
|
368 |
+
elif header.value == "run_name":
|
369 |
+
sheet.cell(row=next_row, column=i, value=Dirs.run_name)
|
370 |
|
371 |
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
372 |
elif header.value in self.wfo_headers_no_lists:
|
|
|
616 |
##################################################################################################################################
|
617 |
################################################## OCR ##################################################################
|
618 |
##################################################################################################################################
|
619 |
+
def perform_OCR_and_save_results(self, image_index, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
|
620 |
self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')
|
621 |
# self.OCR - None
|
622 |
|
623 |
### Process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
|
624 |
+
ocr_google = OCREngine(self.logger, json_report, self.dir_home, self.is_hf, self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
|
625 |
ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
|
626 |
self.OCR = ocr_google.OCR
|
627 |
|
|
|
685 |
|
686 |
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
687 |
json_report.set_text(text_main='Starting OCR')
|
688 |
+
self.perform_OCR_and_save_results(i, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
|
689 |
json_report.set_text(text_main='Finished OCR')
|
690 |
|
691 |
if not self.OCR:
|
|
|
800 |
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
801 |
# Saving the JSON and XLSX files with the response and updating the final JSON response
|
802 |
if response_candidate is not None:
|
803 |
+
final_JSON_response_updated = self.save_json_and_xlsx(self.Dirs, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
804 |
return final_JSON_response_updated, WFO_record, GEO_record
|
805 |
else:
|
806 |
+
final_JSON_response_updated = self.save_json_and_xlsx(self.Dirs, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
807 |
return final_JSON_response_updated, WFO_record, GEO_record
|
808 |
|
809 |
|
|
|
839 |
return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
|
840 |
|
841 |
|
842 |
+
def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
|
843 |
if response is None:
|
844 |
response = self.JSON_dict_structure
|
845 |
# Insert 'filename' as the first key
|
|
|
848 |
|
849 |
# Then add the null info to the spreadsheet
|
850 |
response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
|
851 |
+
self.add_data_to_excel_from_response(Dirs, self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
|
852 |
|
853 |
### Set completed JSON
|
854 |
else:
|
855 |
response = self.clean_catalog_number(response, filename_without_extension)
|
856 |
self.write_json_to_file(txt_file_path, response)
|
857 |
# add to the xlsx file
|
858 |
+
self.add_data_to_excel_from_response(Dirs, self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
859 |
return response
|
860 |
|
861 |
|
vouchervision/vouchervision_main.py
CHANGED
@@ -3,10 +3,10 @@ VoucherVision - based on LeafMachine2 Processes
|
|
3 |
'''
|
4 |
import os, inspect, sys, shutil
|
5 |
from time import perf_counter
|
6 |
-
currentdir = os.path.dirname(os.path.dirname(inspect.getfile(inspect.currentframe())))
|
7 |
-
parentdir = os.path.dirname(currentdir)
|
8 |
-
sys.path.append(parentdir)
|
9 |
-
sys.path.append(currentdir)
|
10 |
from vouchervision.component_detector.component_detector import detect_plant_components, detect_archival_components
|
11 |
from vouchervision.general_utils import save_token_info_as_csv, print_main_start, check_for_subdirs_VV, load_config_file, load_config_file_testing, report_config, save_config_file, crop_detections_from_images_VV
|
12 |
from vouchervision.directory_structure_VV import Dir_Structure
|
@@ -90,7 +90,14 @@ def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progr
|
|
90 |
else:
|
91 |
upload_to_drive(zip_filepath, zip_filename, is_hf, cfg_private=Voucher_Vision.cfg_private, do_upload=False) ##################################### TODO Make this configurable
|
92 |
|
93 |
-
return last_JSON_response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
def make_zipfile(base_dir, output_filename):
|
96 |
# Determine the directory where the zip file should be saved
|
|
|
3 |
'''
|
4 |
import os, inspect, sys, shutil
|
5 |
from time import perf_counter
|
6 |
+
# currentdir = os.path.dirname(os.path.dirname(inspect.getfile(inspect.currentframe())))
|
7 |
+
# parentdir = os.path.dirname(currentdir)
|
8 |
+
# sys.path.append(parentdir)
|
9 |
+
# sys.path.append(currentdir)
|
10 |
from vouchervision.component_detector.component_detector import detect_plant_components, detect_archival_components
|
11 |
from vouchervision.general_utils import save_token_info_as_csv, print_main_start, check_for_subdirs_VV, load_config_file, load_config_file_testing, report_config, save_config_file, crop_detections_from_images_VV
|
12 |
from vouchervision.directory_structure_VV import Dir_Structure
|
|
|
90 |
else:
|
91 |
upload_to_drive(zip_filepath, zip_filename, is_hf, cfg_private=Voucher_Vision.cfg_private, do_upload=False) ##################################### TODO Make this configurable
|
92 |
|
93 |
+
return {'last_JSON_response': last_JSON_response,
|
94 |
+
'final_WFO_record': final_WFO_record,
|
95 |
+
'final_GEO_record': final_GEO_record,
|
96 |
+
'total_cost': total_cost,
|
97 |
+
'n_failed_OCR': Voucher_Vision.n_failed_OCR,
|
98 |
+
'n_failed_LLM_calls': Voucher_Vision.n_failed_LLM_calls,
|
99 |
+
'zip_filepath': zip_filepath,
|
100 |
+
}
|
101 |
|
102 |
def make_zipfile(base_dir, output_filename):
|
103 |
# Determine the directory where the zip file should be saved
|