Commit ae215ea (parent: 37a138a), committed by phyloforfun
Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
- app.py +82 -29
- run_VoucherVision.py +2 -2
- vouchervision/API_validation.py +5 -8
- vouchervision/LLM_GoogleGemini.py +23 -23
- vouchervision/LLM_GooglePalm2.py +10 -11
- vouchervision/LLM_MistralAI.py +10 -11
- vouchervision/LLM_OpenAI.py +14 -10
- vouchervision/LLM_local_MistralAI.py +10 -11
- vouchervision/LLM_local_cpu_MistralAI.py +10 -11
- vouchervision/OCR_Gemini.py +3 -3
- vouchervision/OCR_google_cloud_vision.py +48 -276
- vouchervision/OCR_llava.py +9 -9
- vouchervision/VoucherVision_Config_Builder.py +28 -7
- vouchervision/model_maps.py +1 -1
- vouchervision/tool_geolocate_HERE.py +321 -0
- vouchervision/tool_taxonomy_WFO.py +324 -0
- vouchervision/tool_wikipedia.py +51 -41
- vouchervision/utils_LLM.py +64 -0
- vouchervision/utils_VoucherVision.py +38 -25
app.py
CHANGED
@@ -7,6 +7,7 @@ import pandas as pd
 from io import BytesIO
 from streamlit_extras.let_it_rain import rain
 from annotated_text import annotated_text
+from transformers import AutoConfig
 
 from vouchervision.LeafMachine2_Config_Builder import write_config_file
 from vouchervision.VoucherVision_Config_Builder import build_VV_config, TestOptionsGPT, TestOptionsPalm, check_if_usable
@@ -999,7 +1000,8 @@ def create_private_file():
 st.write("API keys are stored in `../VoucherVision/PRIVATE_DATA.yaml`.")
 st.write("Deleting this file will allow you to reset API keys. Alternatively, you can edit the keys in the user interface or by manually editing the `.yaml` file in a text editor.")
 st.write("Leave keys blank if you do not intend to use that service.")
-
+st.info("Note: You can manually edit these API keys later by opening the /PRIVATE_DATA.yaml file in a plain text editor.")
+
 st.write("---")
 st.subheader("Google Vision (*Required*) / Google PaLM 2 / Google Gemini")
 st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")
@@ -1008,46 +1010,46 @@ def create_private_file():
 with st.expander("**View Google API Instructions**"):
 
 blog_text_and_image(text="Select your project, then in the search bar, search for `vertex ai` and select the option in the photo below.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_00.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_00.PNG'))
 
 blog_text_and_image(text="On the main overview page, click `Enable All Recommended APIs`. Sometimes this button may be hidden. In that case, enable all of the suggested APIs listed on this page.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_0.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_0.PNG'))
 
 blog_text_and_image(text="Sometimes this button may be hidden. In that case, enable all of the suggested APIs listed on this page.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_2.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_2.PNG'))
 
 blog_text_and_image(text="Make sure that all APIs are enabled.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_1.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_1.PNG'))
 
 blog_text_and_image(text="Find the `Vision AI API` service and go to its page.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_3.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_3.PNG'))
 
 blog_text_and_image(text="Find the `Vision AI API` service and go to its page. This is the API service required to use OCR in VoucherVision and must be enabled.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_6.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_6.PNG'))
 
 blog_text_and_image(text="You can also search for the Vertex AI Vision service.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_4.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_4.PNG'))
 
 blog_text_and_image(text=None,
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_5.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_5.PNG'))
 
 st.subheader("Getting a Google JSON authentication key")
 st.write("Google uses a JSON file to store additional authentication information. Save this file in a safe, private location and assign the `GOOGLE_APPLICATION_CREDENTIALS` value to the file path. For Hugging Face, copy the contents of the JSON file including the `\{\}` and paste it as the secret value.")
 st.write("To download your JSON key...")
 blog_text_and_image(text="Open the navigation menu. Click on the hamburger menu (three horizontal lines) in the top left corner. Go to IAM & Admin. ",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_7.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_7.PNG'),width=300)
 
 blog_text_and_image(text="In the navigation pane, hover over `IAM & Admin` and then click on `Service accounts`.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_8.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_8.PNG'))
 
 blog_text_and_image(text="Find the default Compute Engine service account, select it.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_9.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_9.PNG'))
 
 blog_text_and_image(text="Click `Add Key`.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_10.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_10.PNG'))
 
 blog_text_and_image(text="Select `JSON` and click create. This will download your key. Store this in a safe location. The file path to this safe location is the value that you enter into the `GOOGLE_APPLICATION_CREDENTIALS` value.",
-                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_11.
+                    fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_11.PNG'))
 
 blog_text(text_bold="Store Safely", text=": This file contains sensitive data that can be used to authenticate and bill your Google Cloud account. Never commit it to public repositories or expose it in any way. Always keep it safe and secure.")
 
@@ -1135,21 +1137,24 @@ def create_private_file():
 st.write("---")
 st.subheader("HERE Geocoding")
 st.markdown('Follow these [instructions](https://platform.here.com/sign-up?step=verify-identity) to generate an API key for HERE.')
-
+here_APP_ID = st.text_input("HERE Geocoding App ID", cfg_private['here'].get('APP_ID', ''),
                             help='e.g. a 32-character string',
                             placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
                             type='password')
-
+here_API_KEY = st.text_input("HERE Geocoding API Key", cfg_private['here'].get('API_KEY', ''),
                             help='e.g. a 32-character string',
                             placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
                             type='password')
 
 
 
-st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys,
-
-
-
+st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys,
+          args=[cfg_private,
+                openai_api_key,
+                azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
+                google_application_credentials, google_project_location, google_project_id,
+                mistral_API_KEY,
+                here_APP_ID, here_API_KEY])
 if st.button('Proceed to VoucherVision'):
 st.session_state.private_file = does_private_file_exist()
 st.session_state.proceed_to_private = False
@@ -1157,10 +1162,12 @@ def create_private_file():
 st.rerun()
 
 
-def save_changes_to_API_keys(cfg_private,
-
-
-
+def save_changes_to_API_keys(cfg_private,
+                             openai_api_key,
+                             azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
+                             google_application_credentials, google_project_location, google_project_id,
+                             mistral_API_KEY,
+                             here_APP_ID, here_API_KEY):
 
 # Update the configuration dictionary with the new values
 cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key
@@ -1172,15 +1179,16 @@ def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version
 cfg_private['openai_azure']['OPENAI_API_TYPE'] = azure_openai_api_type
 
 cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'] = google_application_credentials
-cfg_private['google']['GOOGLE_PROJECT_ID'] =
-cfg_private['google']['GOOGLE_LOCATION'] =
+cfg_private['google']['GOOGLE_PROJECT_ID'] = google_project_id
+cfg_private['google']['GOOGLE_LOCATION'] = google_project_location
 
 cfg_private['mistral']['MISTRAL_API_KEY'] = mistral_API_KEY
 
-cfg_private['here']['APP_ID'] =
-cfg_private['here']['API_KEY'] =
+cfg_private['here']['APP_ID'] = here_APP_ID
+cfg_private['here']['API_KEY'] = here_API_KEY
 # Call the function to write the updated configuration to the YAML file
 write_config_file(cfg_private, st.session_state.dir_home, filename="PRIVATE_DATA.yaml")
+st.success(f"API Keys saved to {os.path.join(st.session_state.dir_home, 'PRIVATE_DATA.yaml')}")
 # st.session_state.private_file = does_private_file_exist()
 
 # Function to load a YAML file and update session_state
@@ -1568,6 +1576,25 @@ def content_project_settings(col):
 st.session_state.config['leafmachine']['project']['dir_output'] = st.text_input("Output directory", st.session_state.config['leafmachine']['project'].get('dir_output', ''))
 
 
+def content_tools():
+    st.write("---")
+    st.header('Validation Tools')
+
+    tool_WFO = st.session_state.config['leafmachine']['project']['tool_WFO']
+    st.session_state.config['leafmachine']['project']['tool_WFO'] = st.checkbox(label="Enable World Flora Online taxonomy verification",
+                                                                                help="",
+                                                                                value=tool_WFO)
+
+    tool_GEO = st.session_state.config['leafmachine']['project']['tool_GEO']
+    st.session_state.config['leafmachine']['project']['tool_GEO'] = st.checkbox(label="Enable HERE geolocation hints",
+                                                                                help="",
+                                                                                value=tool_GEO)
+
+    tool_wikipedia = st.session_state.config['leafmachine']['project']['tool_wikipedia']
+    st.session_state.config['leafmachine']['project']['tool_wikipedia'] = st.checkbox(label="Enable Wikipedia verification",
+                                                                                      help="",
+                                                                                      value=tool_wikipedia)
+
 def content_llm_cost():
 st.write("---")
 st.header('LLM Cost Calculator')
@@ -1855,6 +1882,17 @@ def content_ocr_method():
 do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'],key="Enable trOCR2")#,disabled=st.session_state['lacks_GPU'])
 st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
 
+if do_use_trOCR:
+    # st.session_state.config['leafmachine']['project']['trOCR_model_path'] = "microsoft/trocr-large-handwritten"
+    default_trOCR_model_path = st.session_state.config['leafmachine']['project']['trOCR_model_path']
+    user_input_trOCR_model_path = st.text_input("trOCR Hugging Face model path. MUST be a fine-tuned version of 'microsoft/trocr-base-handwritten' or 'microsoft/trocr-large-handwritten', or a microsoft trOCR model.", value=default_trOCR_model_path)
+    if st.session_state.config['leafmachine']['project']['trOCR_model_path'] != user_input_trOCR_model_path:
+        is_valid_mp = is_valid_huggingface_model_path(user_input_trOCR_model_path)
+        if not is_valid_mp:
+            st.error(f"The Hugging Face model path {user_input_trOCR_model_path} is not valid. Please revise.")
+        else:
+            st.session_state.config['leafmachine']['project']['trOCR_model_path'] = user_input_trOCR_model_path
+
 if 'LLaVA' in selected_OCR_options:
 OCR_option_llava = st.radio(
 "Select the LLaVA version",
@@ -1888,6 +1926,15 @@ def content_ocr_method():
 # elif (OCR_option == 'hand') and do_use_trOCR:
 # st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
 
+def is_valid_huggingface_model_path(model_path):
+    try:
+        # Attempt to load the model configuration from Hugging Face Model Hub
+        config = AutoConfig.from_pretrained(model_path)
+        return True  # If the configuration loads successfully, the model path is valid
+    except Exception as e:
+        # If loading the model configuration fails, the model path is not valid
+        return False
+
 @st.cache_data
 def show_collage():
 # Load the image only if it's not already in the session state
@@ -1920,7 +1967,12 @@ def content_collage_overlay():
 st.info("NOTE: We strongly recommend enabling LeafMachine2 cropping if your images are full sized herbarium sheet. Often, the OCR algorithm struggles with full sheets, but works well with the collage images. We have disabled the collage by default for this Hugging Face Space because the Space lacks a GPU and the collage creation takes a bit longer.")
 default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
 st.markdown("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
-st.
+st.markdown("Images that are mostly text (like a scanned notecard, or already cropped images) do not require LM2 collage.")
+
+if st.session_state.is_hf:
+    st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', False), key='do make collage hf')
+else:
+    st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', True), key='do make collage local')
 
 
 option_selected_crops = st.multiselect(label="Components to crop",
@@ -2247,6 +2299,7 @@ def main():
 content_ocr_method()
 
 content_collage_overlay()
+content_tools()
 content_llm_cost()
 content_processing_options()
 content_less_used()
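Taken together, the assignments in save_changes_to_API_keys imply that cfg_private is a nested dictionary of roughly the following shape before it is written to PRIVATE_DATA.yaml. This is a sketch inferred from the subscripting above; the Azure block's full key list and any other sections are not shown in this diff, and all values are placeholders.

# Inferred shape of cfg_private (placeholder values only)
cfg_private = {
    'openai':       {'OPENAI_API_KEY': ''},
    'openai_azure': {'OPENAI_API_TYPE': ''},  # plus the other azure_openai_* fields assigned above
    'google':       {'GOOGLE_APPLICATION_CREDENTIALS': '', 'GOOGLE_PROJECT_ID': '', 'GOOGLE_LOCATION': ''},
    'mistral':      {'MISTRAL_API_KEY': ''},
    'here':         {'APP_ID': '', 'API_KEY': ''},
}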
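The trOCR model-path check added above is cheap to run because AutoConfig.from_pretrained fetches only the model's configuration file, not its weights. A minimal standalone sketch of the same behavior (the model ID is one of the defaults named in the UI text; network access is assumed):

from transformers import AutoConfig

def is_valid_huggingface_model_path(model_path):
    try:
        # Only config.json is downloaded, so this is a lightweight existence check.
        AutoConfig.from_pretrained(model_path)
        return True
    except Exception:
        return False

print(is_valid_huggingface_model_path("microsoft/trocr-base-handwritten"))  # True
print(is_valid_huggingface_model_path("no-such-org/no-such-model"))         # False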
run_VoucherVision.py
CHANGED
@@ -31,7 +31,7 @@ def resolve_path(path):
 if __name__ == "__main__":
     dir_home = os.path.dirname(__file__)
 
-    start_port =
+    start_port = 8528
     try:
         free_port = find_available_port(start_port)
         sys.argv = [
@@ -41,7 +41,7 @@ if __name__ == "__main__":
             # resolve_path(os.path.join(dir_home,"vouchervision", "VoucherVision_GUI.py")),
             "--global.developmentMode=false",
             # "--server.port=8545",
-            "--server.port=
+            f"--server.port={free_port}",
             # Toggle below for HF vs Local
             # "--is_hf=1",
             # "--is_hf=0",
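find_available_port itself is not shown in this diff. A typical implementation, offered here as an assumption rather than the project's actual code, probes ports upward from start_port until a socket bind succeeds:

import socket

def find_available_port(start_port, max_tries=100):
    # Try to bind each candidate port; the first successful bind is free.
    for port in range(start_port, start_port + max_tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return port
            except OSError:
                continue
    raise RuntimeError(f"No free port in {start_port}-{start_port + max_tries - 1}")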
vouchervision/API_validation.py
CHANGED
@@ -36,10 +36,11 @@ class APIvalidation:
 
 
     def has_API_key(self, val):
-
-
-
-
+        return isinstance(val, str) and bool(val.strip())
+        # if val:
+        #     return True
+        # else:
+        #     return False
 
     def check_openai_api_key(self):
         if self.is_hf:
@@ -192,10 +193,6 @@ class APIvalidation:
             print(f"palm2 fail2")
 
         try:
-            # https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm
-            # os.environ['GOOGLE_API_KEY'] = "AIzaSyAHOH1w1qV7C3jS4W7QFyoaTGUwZIgS5ig"
-            # genai.configure(api_key='AIzaSyC8xvu6t9fb5dTah3hpgg_rwwR5G5kianI')
-            # model = ChatGoogleGenerativeAI(model="text-bison@001")
             model = VertexAI(model="text-bison@001", max_output_tokens=10)
             response = model.predict("Hello")
             test_response_palm2 = response
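The rewritten has_API_key is stricter than the truthiness test it replaces (kept above as comments): the value must now be a string containing non-whitespace characters. Illustrative calls, with validator standing in for an APIvalidation instance:

validator.has_API_key("sk-abc123")  # True
validator.has_API_key("   ")        # False: whitespace-only, though the old `if val:` would have passed it
validator.has_API_key("")           # False
validator.has_API_key(None)         # False: non-strings are rejected instead of relying on truthiness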
vouchervision/LLM_GoogleGemini.py
CHANGED
@@ -6,14 +6,11 @@ from langchain.output_parsers import RetryWithErrorOutputParser
 # from langchain.schema import HumanMessage
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
-
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_vertexai import VertexAI
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 class GoogleGeminiHandler:
 
@@ -23,7 +20,12 @@ class GoogleGeminiHandler:
     VENDOR = 'google'
     STARTING_TEMP = 0.5
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
@@ -76,13 +78,13 @@ class GoogleGeminiHandler:
 
     def _build_model_chain_parser(self):
         # Instantiate the LLM class for Google Gemini
-
-
-
-        self.llm_model = VertexAI(model='gemini-pro',
-
-
-
+        self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)#,
+                                        # max_output_tokens=self.config.get('max_output_tokens'),
+                                        # top_p=self.config.get('top_p'))
+        # self.llm_model = VertexAI(model='gemini-1.0-pro',
+        #                           max_output_tokens=self.config.get('max_output_tokens'),
+        #                           top_p=self.config.get('top_p'))
+
         # Set up the retry parser with the runnable
         self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
         # Prepare the chain
@@ -90,10 +92,10 @@ class GoogleGeminiHandler:
 
     # Define a function to format the input for Google Gemini call
    def call_google_gemini(self, prompt_text):
-        model = GenerativeModel(self.model_name)
-
-
-
+        model = GenerativeModel(self.model_name)#,
+                                # generation_config=self.config,
+                                # safety_settings=self.safety_settings)
+        response = model.generate_content(prompt_text.text)
         return response.text
 
    def call_llm_api_GoogleGemini(self, prompt_template, json_report, paths):
@@ -130,13 +132,9 @@ class GoogleGeminiHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-
-        Wiki.gather_wikipedia_results(output)
-
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -156,6 +154,8 @@ class GoogleGeminiHandler:
 
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
+
+        self.monitor.stop_inference_timer() # Starts tool timer too
 
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
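The Gemini handler now builds its chain on langchain_google_genai.ChatGoogleGenerativeAI rather than the Vertex AI wrapper (the old VertexAI call is kept above as a comment). A minimal usage sketch, assuming GOOGLE_API_KEY is set in the environment; the model name here is illustrative:

from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
print(llm.invoke("Reply with the single word OK").content)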
vouchervision/LLM_GooglePalm2.py
CHANGED
@@ -11,11 +11,8 @@ from langchain_core.output_parsers import JsonOutputParser
 # from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_vertexai import VertexAI
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 #https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk
 #pip install --upgrade google-cloud-aiplatform
@@ -34,7 +31,12 @@ class GooglePalm2Handler:
     VENDOR = 'google'
     STARTING_TEMP = 0.5
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
@@ -144,13 +146,9 @@ class GooglePalm2Handler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
 
-        save_individual_prompt(
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -171,6 +169,7 @@ class GooglePalm2Handler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
 
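Every handler in this commit replaces its inline WFO/geolocation/Wikipedia calls with the same one-line run_tools call. The function itself lives in vouchervision/utils_LLM.py, which this commit adds (+64 lines) but does not display. The sketch below is an assumption consistent with the call sites and with the commented-out lines left in LLM_OpenAI.py; the tool module paths are taken from the diffstat (tool_taxonomy_WFO.py, tool_geolocate_HERE.py, tool_wikipedia.py) and the replace_if_success flags mirror the removed inline calls:

from vouchervision.tool_taxonomy_WFO import validate_taxonomy_WFO
from vouchervision.tool_geolocate_HERE import validate_coordinates_here
from vouchervision.tool_wikipedia import WikipediaLinks

def run_tools(output, tool_WFO, tool_GEO, tool_wikipedia, json_file_path_wiki):
    # Each tool is gated by its config flag; replace_if_success is False, matching the old inline calls.
    output_WFO, WFO_record = validate_taxonomy_WFO(tool_WFO, output, replace_if_success_wfo=False)
    output_GEO, GEO_record = validate_coordinates_here(tool_GEO, output, replace_if_success_geo=False)
    if tool_wikipedia:
        Wiki = WikipediaLinks(json_file_path_wiki)
        Wiki.gather_wikipedia_results(output)
    return output_WFO, WFO_record, output_GEO, GEO_record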
vouchervision/LLM_MistralAI.py
CHANGED
@@ -4,11 +4,8 @@ from langchain.output_parsers import RetryWithErrorOutputParser
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 
 class MistralHandler:
@@ -19,7 +16,12 @@ class MistralHandler:
     VENDOR = 'mistral'
     RANDOM_SEED = 2023
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.monitor = SystemLoadMonitor(logger)
         self.has_GPU = torch.cuda.is_available()
@@ -115,13 +117,9 @@ class MistralHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -142,6 +140,7 @@ class MistralHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
         json_report.set_text(text_main=f'LLM call failed')
vouchervision/LLM_OpenAI.py
CHANGED
@@ -5,11 +5,8 @@ from langchain.schema import HumanMessage
 from langchain_core.output_parsers import JsonOutputParser
 from langchain.output_parsers import RetryWithErrorOutputParser
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 class OpenAIHandler:
     RETRY_DELAY = 10 # Wait 10 seconds before retrying
@@ -18,7 +15,12 @@ class OpenAIHandler:
     TOKENIZER_NAME = 'gpt-4'
     VENDOR = 'openai'
 
-    def __init__(self, logger, model_name, JSON_dict_structure, is_azure, llm_object):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
@@ -135,13 +137,14 @@ class OpenAIHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-
-
+
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-
-
+        # output1, WFO_record = validate_taxonomy_WFO(self.tool_WFO, output, replace_if_success_wfo=False)
+        # output2, GEO_record = validate_coordinates_here(self.tool_GEO, output, replace_if_success_geo=False)
+        # validate_wikipedia(self.tool_wikipedia, json_file_path_wiki, output)
 
-        save_individual_prompt(
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -162,6 +165,7 @@ class OpenAIHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
 
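A second pattern repeated across the handlers: stop_inference_timer is now also called on the failure path, so tool time is accounted for even when JSON extraction fails. SystemLoadMonitor's implementation is not shown in this diff; its lifecycle, inferred from the call sites only, looks like this:

from vouchervision.utils_LLM import SystemLoadMonitor
import logging

logger = logging.getLogger(__name__)
monitor = SystemLoadMonitor(logger)   # presumably begins sampling system load at construction
# ... run the LLM call and JSON parsing here ...
monitor.stop_inference_timer()        # per its inline comment, ends inference timing and starts the tool timer
usage_report = monitor.stop_monitoring_report_usage()  # final usage summary returned to the caller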
vouchervision/LLM_local_MistralAI.py
CHANGED
@@ -6,11 +6,8 @@ from langchain_core.output_parsers import JsonOutputParser
 from huggingface_hub import hf_hub_download
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 '''
 Local Pipielines:
@@ -25,7 +22,12 @@ class LocalMistralHandler:
     VENDOR = 'mistral'
     MAX_GPU_MONITORING_INTERVAL = 2 # seconds
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.has_GPU = torch.cuda.is_available()
         self.monitor = SystemLoadMonitor(logger)
@@ -188,13 +190,9 @@ class LocalMistralHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -214,6 +212,7 @@ class LocalMistralHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         json_report.set_text(text_main=f'LLM call failed')
 
vouchervision/LLM_local_cpu_MistralAI.py
CHANGED
@@ -18,11 +18,8 @@ from langchain.callbacks.base import BaseCallbackHandler
 from huggingface_hub import hf_hub_download
 
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 class LocalCPUMistralHandler:
     RETRY_DELAY = 2 # Wait 2 seconds before retrying
@@ -33,7 +30,12 @@ class LocalCPUMistralHandler:
     SEED = 2023
 
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.monitor = SystemLoadMonitor(logger)
         self.has_GPU = torch.cuda.is_available()
@@ -179,13 +181,9 @@ class LocalCPUMistralHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -204,6 +202,7 @@ class LocalCPUMistralHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
        usage_report = self.monitor.stop_monitoring_report_usage()
        self._reset_config()
 
vouchervision/OCR_Gemini.py
CHANGED
@@ -145,16 +145,16 @@ maximumElevationInMeters
 }
 """
 def _get_google_credentials():
-    with open('
+    with open('', 'r') as file:
         data = json.load(file)
         creds_json_str = json.dumps(data)
         credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
         os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = creds_json_str
-    os.environ['GOOGLE_API_KEY'] = '
+    os.environ['GOOGLE_API_KEY'] = ''
     return credentials
 
 if __name__ == '__main__':
-    vertexai.init(project='
+    vertexai.init(project='', location='', credentials=_get_google_credentials())
 
     logger = logging.getLogger('LLaVA')
     logger.setLevel(logging.DEBUG)
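The edit above scrubs a hard-coded service-account path, API key, project, and location down to empty strings. A common alternative, shown here as a hedged sketch rather than the project's actual approach, is to resolve the JSON path from the environment instead of embedding it in the source:

import json
import os
from google.oauth2 import service_account

def get_google_credentials_from_env():
    # GOOGLE_APPLICATION_CREDENTIALS holds the path to the service-account JSON file.
    creds_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
    with open(creds_path, 'r') as file:
        info = json.load(file)
    return service_account.Credentials.from_service_account_info(info)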
vouchervision/OCR_google_cloud_vision.py
CHANGED
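The first hunk below drops the module-level try/except imports for CRAFT and LLaVA. The lines added to init_craft and init_llava are not rendered in this view, but the likely replacement is a lazy import inside each init method so the heavy dependencies load only when their OCR option is selected. A sketch of that pattern (an assumption, not the verbatim new code):

def init_craft(self):
    if 'CRAFT' in self.OCR_option:
        # Import lazily so craft_text_detector is only required when CRAFT OCR is enabled.
        from craft_text_detector import load_craftnet_model, load_refinenet_model
        try:
            self.refine_net = load_refinenet_model(cuda=True)
            self.use_cuda = True
        except Exception:
            self.refine_net = load_refinenet_model(cuda=False)
            self.use_cuda = False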
@@ -10,14 +10,6 @@ from google.oauth2 import service_account
 
 ### LLaVA should only be installed if the user will actually use it.
 ### It requires the most recent pytorch/Python and can mess with older systems
-try:
-    from craft_text_detector import read_image, load_craftnet_model, load_refinenet_model, get_prediction, export_detected_regions, export_extra_results, empty_cuda_cache
-except:
-    pass
-try:
-    from OCR_llava import OCRllava
-except:
-    pass
 
 
 '''
@@ -92,9 +84,7 @@ class OCREngine:
 
         self.multimodal_prompt = """I need you to transcribe all of the text in this image.
         Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
-
-        if 'LLaVA' in self.OCR_option:
-            self.init_llava()
+        self.init_llava()
 
 
     def set_client(self):
@@ -113,6 +103,8 @@ class OCREngine:
 
     def init_craft(self):
         if 'CRAFT' in self.OCR_option:
+            from craft_text_detector import load_craftnet_model, load_refinenet_model
+
            try:
                self.refine_net = load_refinenet_model(cuda=True)
                self.use_cuda = True
@@ -126,21 +118,23 @@ class OCREngine:
             self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
 
     def init_llava(self):
+        if 'LLaVA' in self.OCR_option:
+            from vouchervision.OCR_llava import OCRllava
+
+            self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
+            self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
+
+            self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
+
+            if self.model_quant == '4bit':
+                use_4bit = True
+            elif self.model_quant == 'full':
+                use_4bit = False
+            else:
+                self.logger.info(f"Provided model quantization invlid. Using 4bit.")
+                use_4bit = True
+
+            self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
 
     def init_gemini_vision(self):
         pass
@@ -150,6 +144,8 @@ class OCREngine:
 
 
     def detect_text_craft(self):
+        from craft_text_detector import read_image, get_prediction
+
         # Perform prediction using CRAFT
         image = read_image(self.path)
 
@@ -250,13 +246,13 @@ class OCREngine:
         if not do_use_trOCR:
             if 'normal' in self.OCR_option:
                 self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
-                logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
+                # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
                 # ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
                 ocr_parts = self.normal_organized_text
 
             if 'hand' in self.OCR_option:
                 self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
-                logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
+                # logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
                 # ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
                 ocr_parts = self.hand_organized_text
 
@@ -340,13 +336,13 @@ class OCREngine:
         if 'normal' in self.OCR_option:
             self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
             self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
-            logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
+            # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
             # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
             ocr_parts = self.trOCR_texts
         if 'hand' in self.OCR_option:
             self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
             self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
-            logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
+            # logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
             # ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
             ocr_parts = self.trOCR_texts
         # if self.OCR_option in ['both',]:
@@ -358,7 +354,7 @@ class OCREngine:
         if 'CRAFT' in self.OCR_option:
             # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
             self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
-            logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
+            # logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
             # ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
             ocr_parts = self.trOCR_texts
         return ocr_parts
@@ -383,7 +379,10 @@ class OCREngine:
 
         for bound, confidence, char_height, character in zip(bounds_flat, confidences, heights, characters):
             font_size = int(char_height)
-            font = ImageFont.load_default().font_variant(size=font_size)
+            try:
+                font = ImageFont.truetype("arial.ttf", font_size)
+            except:
+                font = ImageFont.load_default().font_variant(size=font_size)
             if option == 'trOCR':
                 color = (0, 170, 255)
             else:
@@ -686,7 +685,7 @@ class OCREngine:
             self.OCR = self.OCR + part_OCR + part_OCR
         else:
             self.OCR = self.OCR + "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
-        logger.info(f"CRAFT trOCR:\n{self.OCR}")
+        # logger.info(f"CRAFT trOCR:\n{self.OCR}")
 
         if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
             self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
@@ -704,25 +703,34 @@ class OCREngine:
             self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
         else:
             self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
-        logger.info(f"LLaVA OCR:\n{self.OCR}")
+        # logger.info(f"LLaVA OCR:\n{self.OCR}")
 
         if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
             if 'normal' in self.OCR_option:
-                self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
+                if self.double_OCR:
+                    part_OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
+                    self.OCR = self.OCR + part_OCR + part_OCR
+                else:
+                    self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
             if 'hand' in self.OCR_option:
-                self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
+                if self.double_OCR:
+                    part_OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
+                    self.OCR = self.OCR + part_OCR + part_OCR
+                else:
+                    self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
             # if self.OCR_option not in ['normal', 'hand', 'both']:
             #     self.OCR_option = 'both'
             #     self.detect_text()
             #     self.detect_handwritten_ocr()
 
         ### Optionally add trOCR to the self.OCR for additional context
-        if self.do_use_trOCR:
-            self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
-            logger.info(f"OCR:\n{self.OCR}")
+        if self.do_use_trOCR:
+            if self.double_OCR:
+                part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
+                self.OCR = self.OCR + part_OCR + part_OCR
+            else:
+                self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
+        # logger.info(f"OCR:\n{self.OCR}")
 
         if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
             self.image = Image.open(self.path)
@@ -744,8 +752,6 @@ class OCREngine:
                 image_with_boxes_normal = self.draw_boxes('normal')
                 self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)
 
-
-
             ### Merge final overlay image
             ### [original, normal bboxes, normal text]
             if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
@@ -773,241 +779,7 @@ class OCREngine:
         self.overlay_image = Image.open(self.path)
 
         try:
+            from craft_text_detector import empty_cuda_cache
             empty_cuda_cache()
         except:
             pass
-
-
-
-'''
-BBOX_COLOR = "black" # green cyan
-
-def render_text_on_black_image(image_path, handwritten_char_bounds_flat, handwritten_char_confidences, handwritten_char_heights, characters):
-    # Load the original image to get its dimensions
-    original_image = Image.open(image_path)
-    width, height = original_image.size
-
-    # Create a black image of the same size
-    black_image = Image.new("RGB", (width, height), "black")
-    draw = ImageDraw.Draw(black_image)
-
-    # Loop through each character
-    for bound, confidence, char_height, character in zip(handwritten_char_bounds_flat, handwritten_char_confidences, handwritten_char_heights, characters):
-        # Determine the font size based on the height of the character
-        font_size = int(char_height)
-        font = ImageFont.load_default().font_variant(size=font_size)
-
-        # Color of the character
-        color = confidence_to_color(confidence)
-
-        # Position of the text (using the bottom-left corner of the bounding box)
-        position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
-
-        # Draw the character
-        draw.text(position, character, fill=color, font=font)
-
-    return black_image
-
-def merge_images(image1, image2):
-    # Assuming both images are of the same size
-    width, height = image1.size
-    merged_image = Image.new("RGB", (width * 2, height))
-    merged_image.paste(image1, (0, 0))
-    merged_image.paste(image2, (width, 0))
-    return merged_image
-
-def draw_boxes(image, bounds, color):
-    if bounds:
-        draw = ImageDraw.Draw(image)
-        width, height = image.size
-        line_width = int((width + height) / 2 * 0.001) # This sets the line width as 0.5% of the average dimension
-
-        for bound in bounds:
-            draw.polygon(
-                [
-                    bound["vertices"][0]["x"], bound["vertices"][0]["y"],
-                    bound["vertices"][1]["x"], bound["vertices"][1]["y"],
-                    bound["vertices"][2]["x"], bound["vertices"][2]["y"],
-                    bound["vertices"][3]["x"], bound["vertices"][3]["y"],
-                ],
-                outline=color,
-                width=line_width
-            )
-    return image
-
-def detect_text(path):
-    client = vision.ImageAnnotatorClient()
-    with io.open(path, 'rb') as image_file:
-        content = image_file.read()
-    image = vision.Image(content=content)
-    response = client.document_text_detection(image=image)
-    texts = response.text_annotations
-
-    if response.error.message:
-        raise Exception(
-            '{}\nFor more info on error messages, check: '
-            'https://cloud.google.com/apis/design/errors'.format(
-                response.error.message))
-
-    # Extract bounding boxes
-    bounds = []
-    text_to_box_mapping = {}
-    for text in texts[1:]:  # Skip the first entry, as it represents the entire detected text
-        # Convert BoundingPoly to dictionary
-        bound_dict = {
-            "vertices": [
-                {"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices
-            ]
-        }
-        bounds.append(bound_dict)
-        text_to_box_mapping[str(bound_dict)] = text.description
-
-    if texts:
-        # cleaned_text = texts[0].description.replace("\n", " ").replace("\t", " ").replace("|", " ")
-        cleaned_text = texts[0].description
-        return cleaned_text, bounds, text_to_box_mapping
-    else:
-        return '', None, None
-
-def confidence_to_color(confidence):
-    """Convert confidence level to a color ranging from red (low confidence) to green (high confidence)."""
-    # Using HSL color space, where Hue varies from red to green
-    hue = (confidence - 0.5) * 120 / 0.5  # Scale confidence to range 0-120 (red to green in HSL)
-    r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 1)  # Convert to RGB
-    return (int(r*255), int(g*255), int(b*255))
-
-def overlay_boxes_on_image(path, typed_bounds, handwritten_char_bounds, handwritten_char_confidences, do_create_OCR_helper_image):
-    if do_create_OCR_helper_image:
-        image = Image.open(path)
-        draw = ImageDraw.Draw(image)
-        width, height = image.size
-        line_width = int((width + height) / 2 * 0.005)  # Adjust line width for character level
-
-        # Draw boxes for typed text
-        for bound in typed_bounds:
-            draw.polygon(
-                [
-                    bound["vertices"][0]["x"], bound["vertices"][0]["y"],
-                    bound["vertices"][1]["x"], bound["vertices"][1]["y"],
-                    bound["vertices"][2]["x"], bound["vertices"][2]["y"],
-                    bound["vertices"][3]["x"], bound["vertices"][3]["y"],
-                ],
-                outline=BBOX_COLOR,
-                width=1
-            )
-
-        # Draw a line segment at the bottom of each handwritten character
-        for bound, confidence in zip(handwritten_char_bounds, handwritten_char_confidences):
-            color = confidence_to_color(confidence)
-            # Use the bottom two vertices of the bounding box for the line
-            bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width)
-            bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width)
-            draw.line([bottom_left, bottom_right], fill=color, width=line_width)
-
-        text_image = render_text_on_black_image(path, handwritten_char_bounds, handwritten_char_confidences)
-        merged_image = merge_images(image, text_image)  # Assuming 'overlayed_image' is the image with lines
-
-
-        return merged_image
-    else:
-        return Image.open(path)
-
-def detect_handwritten_ocr(path):
-    """Detects handwritten characters in a local image and returns their bounding boxes and confidence levels.
-
-    Args:
-        path: The path to the local file.
-
-    Returns:
-        A tuple of (text, bounding_boxes, confidences)
-    """
-    client = vision_beta.ImageAnnotatorClient()
-
-    with open(path, "rb") as image_file:
-        content = image_file.read()
-
-    image = vision_beta.Image(content=content)
-    image_context = vision_beta.ImageContext(language_hints=["en-t-i0-handwrit"])
-    response = client.document_text_detection(image=image, image_context=image_context)
-
-    if response.error.message:
-        raise Exception(
-            "{}\nFor more info on error messages, check: "
-            "https://cloud.google.com/apis/design/errors".format(response.error.message)
-        )
-
-    bounds = []
-    bounds_flat = []
-    height_flat = []
-    confidences = []
-    character = []
-    for page in response.full_text_annotation.pages:
-        for block in page.blocks:
-            for paragraph in block.paragraphs:
-                for word in paragraph.words:
-                    # Get the bottom Y-location (max Y) for the whole word
-                    Y = max(vertex.y for vertex in word.bounding_box.vertices)
-
-                    # Get the height of the word's bounding box
-                    H = Y - min(vertex.y for vertex in word.bounding_box.vertices)
-
-                    for symbol in word.symbols:
-                        # Collecting bounding box for each symbol
-                        bound_dict = {
-                            "vertices": [
-                                {"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
-                            ]
-                        }
-                        bounds.append(bound_dict)
-
-                        # Bounds with same bottom y height
-                        bounds_flat_dict = {
-                            "vertices": [
-                                {"x": vertex.x, "y": Y} for vertex in symbol.bounding_box.vertices
-                            ]
-                        }
-                        bounds_flat.append(bounds_flat_dict)
-
-                        # Add the word's height
-                        height_flat.append(H)
-
-                        # Collecting confidence for each symbol
-                        symbol_confidence = round(symbol.confidence, 4)
-                        confidences.append(symbol_confidence)
-                        character.append(symbol.text)
-
-    cleaned_text = response.full_text_annotation.text
-
-    return cleaned_text, bounds, bounds_flat, height_flat, confidences, character
-
-
-
-def process_image(path, do_create_OCR_helper_image):
-    typed_text, typed_bounds, _ = detect_text(path)
-    handwritten_text, handwritten_bounds, _ = detect_handwritten_ocr(path)
-
-    overlayed_image = overlay_boxes_on_image(path, typed_bounds, handwritten_bounds, do_create_OCR_helper_image)
-    return typed_text, handwritten_text, overlayed_image
-
-'''
-
-# ''' Google Vision'''
-# def detect_text(path):
-#     """Detects text in the file located in the local filesystem."""
-#     client = vision.ImageAnnotatorClient()
-
-#     with io.open(path, 'rb') as image_file:
-#         content = image_file.read()
-
-#     image = vision.Image(content=content)
-
-#     response = client.document_text_detection(image=image)
-#     texts = response.text_annotations
-
-#     if response.error.message:
-#         raise Exception(
-#             '{}\nFor more info on error messages, check: '
-#             'https://cloud.google.com/apis/design/errors'.format(
-#                 response.error.message))
-
-#     return texts[0].description if texts else ''
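
The common thread in the OCR_google_cloud_vision.py hunks above is that the heavy optional backends (CRAFT, LLaVA) are no longer imported at module load; each init_*/detect_* method now imports its own dependencies only when the matching OCR option is selected. A minimal sketch of that deferred-import pattern, with heavy_ocr_lib and HeavyBackend as hypothetical stand-ins for the real dependencies:

class LazyOCREngine:
    """Sketch only: defer a costly import until the feature is requested."""
    def __init__(self, OCR_option):
        self.OCR_option = OCR_option
        self.backend = None
        self.init_heavy()

    def init_heavy(self):
        if 'heavy' in self.OCR_option:
            # Importing here means users who never select this option do not
            # need the optional dependency (or its torch version) installed.
            from heavy_ocr_lib import HeavyBackend  # hypothetical package
            self.backend = HeavyBackend()

engine = LazyOCREngine(OCR_option=['normal'])  # no heavy import is triggered
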
vouchervision/OCR_llava.py
CHANGED
@@ -3,20 +3,20 @@ import requests
 from PIL import Image
 from io import BytesIO
 import torch
-from transformers import AutoTokenizer, BitsAndBytesConfig, TextStreamer
+# from transformers import AutoTokenizer, BitsAndBytesConfig, TextStreamer
 
-from langchain.prompts import PromptTemplate
+# from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.pydantic_v1 import BaseModel, Field
 
-from LLaVA.llava.model import LlavaLlamaForCausalLM
-from LLaVA.llava.model.builder import load_pretrained_model
-from LLaVA.llava.conversation import conv_templates
-from LLaVA.llava.utils import disable_torch_init
-from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
-from LLaVA.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+# from vouchervision.LLaVA.llava.model import LlavaLlamaForCausalLM
+from vouchervision.LLaVA.llava.model.builder import load_pretrained_model
+from vouchervision.LLaVA.llava.conversation import conv_templates #, SeparatorStyle
+from vouchervision.LLaVA.llava.utils import disable_torch_init
+from vouchervision.LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
+from vouchervision.LLaVA.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images #KeywordsStoppingCriteria
 
-from utils_LLM import SystemLoadMonitor
+from vouchervision.utils_LLM import SystemLoadMonitor
 
 '''
 Performance expectations system:
vouchervision/VoucherVision_Config_Builder.py
CHANGED
@@ -36,16 +36,22 @@ def build_VV_config(loaded_cfg=None):
         save_cropped_annotations = ['label','barcode']
 
         do_use_trOCR = False
+        trOCR_model_path = "microsoft/trocr-large-handwritten"
         OCR_option = 'hand'
         OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
         OCR_option_llava_bit = 'full' # full or 4bit
         double_OCR = False
 
+
+        tool_GEO = True
+        tool_WFO = True
+        tool_wikipedia = True
+
         check_for_illegal_filenames = False
 
         LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
-        prompt_version = '
-        use_LeafMachine2_collage_images =
+        prompt_version = 'SLTPvA_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
+        use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
         do_create_OCR_helper_image = True
 
         batch_size = 500
@@ -54,8 +60,8 @@ def build_VV_config(loaded_cfg=None):
         skip_vertical = False
         pdf_conversion_dpi = 100
 
-        path_domain_knowledge = os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
-        embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]
+        path_domain_knowledge = '' #os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
+        embeddings_database_name = '' #os.path.splitext(os.path.basename(path_domain_knowledge))[0]
 
         #############################################
         #############################################
@@ -65,7 +71,9 @@ def build_VV_config(loaded_cfg=None):
         return assemble_config(dir_home, run_name, dir_images_local,dir_output,
                         prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                         path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-                        prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava,
+                        prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
+                        OCR_option_llava_bit, double_OCR, save_cropped_annotations,
+                        tool_GEO, tool_WFO, tool_wikipedia,
                        check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
     else:
         dir_home = os.path.dirname(os.path.dirname(__file__))
@@ -80,11 +88,16 @@ def build_VV_config(loaded_cfg=None):
         catalog_numerical_only = loaded_cfg['leafmachine']['project']['catalog_numerical_only']
 
         do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
+        trOCR_model_path = loaded_cfg['leafmachine']['project']['trOCR_model_path']
         OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
         OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
        OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
         double_OCR = loaded_cfg['leafmachine']['project']['double_OCR']
 
+        tool_GEO = loaded_cfg['leafmachine']['project']['tool_GEO']
+        tool_WFO = loaded_cfg['leafmachine']['project']['tool_WFO']
+        tool_wikipedia = loaded_cfg['leafmachine']['project']['tool_wikipedia']
+
         pdf_conversion_dpi = loaded_cfg['leafmachine']['project']['pdf_conversion_dpi']
 
         LLM_version_user = loaded_cfg['leafmachine']['LLM_version']
@@ -105,14 +118,18 @@ def build_VV_config(loaded_cfg=None):
         return assemble_config(dir_home, run_name, dir_images_local,dir_output,
                         prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                         path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-                        prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava,
+                        prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
+                        OCR_option_llava_bit, double_OCR, save_cropped_annotations,
+                        tool_GEO, tool_WFO, tool_wikipedia,
                        check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
 
 
 def assemble_config(dir_home, run_name, dir_images_local,dir_output,
                 prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                 path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-                prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, OCR_option_llava,
+                prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
+                OCR_option_llava_bit, double_OCR, save_cropped_annotations,
+                tool_GEO, tool_WFO, tool_wikipedia,
                 check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
 
@@ -157,11 +174,15 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
         'delete_all_temps': False,
         'delete_temps_keep_VVE': False,
         'do_use_trOCR': do_use_trOCR,
+        'trOCR_model_path': trOCR_model_path,
         'OCR_option': OCR_option,
         'OCR_option_llava': OCR_option_llava,
         'OCR_option_llava_bit': OCR_option_llava_bit,
         'double_OCR': double_OCR,
         'pdf_conversion_dpi': pdf_conversion_dpi,
+        'tool_GEO': tool_GEO,
+        'tool_WFO': tool_WFO,
+        'tool_wikipedia': tool_wikipedia,
     }
 
     modules_section = {
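
For reference, the project section that assemble_config now emits carries all of the new OCR and tool toggles. A condensed sketch of just the new or changed keys, populated with the build_VV_config defaults shown in the first hunk:

# Sketch: new keys in config['leafmachine']['project'] after this commit,
# values taken from the build_VV_config defaults above.
project_section_new_keys = {
    'do_use_trOCR': False,
    'trOCR_model_path': 'microsoft/trocr-large-handwritten',
    'OCR_option': 'hand',
    'OCR_option_llava': 'llava-v1.6-mistral-7b',
    'OCR_option_llava_bit': 'full',
    'double_OCR': False,
    'tool_GEO': True,
    'tool_WFO': True,
    'tool_wikipedia': True,
}
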
vouchervision/model_maps.py
CHANGED
@@ -206,7 +206,7 @@ class ModelMaps:
             return "text-unicorn@001"
 
         elif key == 'GEMINI_PRO':
-            return "gemini-pro"
+            return "gemini-1.0-pro"
 
         ### Mistral
         elif key == 'MISTRAL_TINY':
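
The GEMINI_PRO mapping now pins a versioned model id rather than the floating gemini-pro alias. A hedged sketch of consuming that string with the google-generativeai client; VoucherVision's own wrapper (LLM_GoogleGemini.py) may wire this differently, and the API key below is a placeholder:

import google.generativeai as genai

genai.configure(api_key='YOUR_GOOGLE_API_KEY')   # placeholder; VoucherVision reads keys from PRIVATE_DATA.yaml
model = genai.GenerativeModel('gemini-1.0-pro')  # the id ModelMaps now returns for GEMINI_PRO
response = model.generate_content('Transcribe the text on this herbarium label: ...')
print(response.text)
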
vouchervision/tool_geolocate_HERE.py
ADDED
@@ -0,0 +1,321 @@
+import os, requests
+import pycountry_convert as pc
+import unicodedata
+import pycountry_convert as pc
+import warnings
+
+
+def normalize_country_name(name):
+    return unicodedata.normalize('NFKD', name).encode('ASCII', 'ignore').decode('ASCII')
+
+def get_continent(country_name):
+    warnings.filterwarnings("ignore", category=UserWarning, module='pycountry')
+
+    continent_code_to_name = {
+        "AF": "Africa",
+        "NA": "North America",
+        "OC": "Oceania",
+        "AN": "Antarctica",
+        "AS": "Asia",
+        "EU": "Europe",
+        "SA": "South America"
+    }
+
+    try:
+        normalized_country_name = normalize_country_name(country_name)
+        # Get country alpha2 code
+        country_code = pc.country_name_to_country_alpha2(normalized_country_name)
+        # Get continent code from country alpha2 code
+        continent_code = pc.country_alpha2_to_continent_code(country_code)
+        # Map the continent code to continent name
+        return continent_code_to_name.get(continent_code, '')
+    except Exception as e:
+        print(str(e))
+        return ''
+
+def validate_coordinates_here(tool_GEO, record, replace_if_success_geo=False):
+    forward_url = 'https://geocode.search.hereapi.com/v1/geocode'
+    reverse_url = 'https://revgeocode.search.hereapi.com/v1/revgeocode'
+
+    pinpoint = ['GEO_city','GEO_county','GEO_state','GEO_country',]
+    GEO_dict_null = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    GEO_dict = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    GEO_dict_rev = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    GEO_dict_rev_verbatim = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    GEO_dict_forward = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    GEO_dict_forward_locality = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+
+    if not tool_GEO:
+        return record, GEO_dict_null
+    else:
+        # For production
+        query_forward = ', '.join(filter(None, [record.get('municipality', '').strip(),
+                                                record.get('county', '').strip(),
+                                                record.get('stateProvince', '').strip(),
+                                                record.get('country', '').strip()])).strip()
+        query_forward_locality = ', '.join(filter(None, [record.get('locality', '').strip(),
+                                                record.get('municipality', '').strip(),
+                                                record.get('county', '').strip(),
+                                                record.get('stateProvince', '').strip(),
+                                                record.get('country', '').strip()])).strip()
+        query_reverse = ','.join(filter(None, [record.get('decimalLatitude', '').strip(),
+                                                record.get('decimalLongitude', '').strip()])).strip()
+        query_reverse_verbatim = record.get('verbatimCoordinates', '').strip()
+
+
+        '''
+        #For testing
+        # query_forward = 'Ann bor, michign'
+        query_forward = 'michigan'
+        query_forward_locality = 'Ann bor, michign'
+        # query_gps = "42 N,-83 W" # cannot have any spaces
+        # query_reverse_verbatim = "42.278366,-83.744718" # cannot have any spaces
+        query_reverse_verbatim = "42,-83" # cannot have any spaces
+        query_reverse = "42,-83" # cannot have any spaces
+        # params = {
+        #     'q': query_loc,
+        #     'apiKey': os.environ['HERE_API_KEY'],
+        # }'''
+
+
+        params_rev = {
+            'at': query_reverse,
+            'apiKey': os.environ['HERE_API_KEY'],
+            'lang': 'en',
+        }
+        params_reverse_verbatim = {
+            'at': query_reverse_verbatim,
+            'apiKey': os.environ['HERE_API_KEY'],
+            'lang': 'en',
+        }
+        params_forward = {
+            'q': query_forward,
+            'apiKey': os.environ['HERE_API_KEY'],
+            'lang': 'en',
+        }
+        params_forward_locality = {
+            'q': query_forward_locality,
+            'apiKey': os.environ['HERE_API_KEY'],
+            'lang': 'en',
+        }
+
+        ### REVERSE
+        # If there are two string in the coordinates, try a reverse first based on the literal coordinates
+        response = requests.get(reverse_url, params=params_rev)
+        if response.status_code == 200:
+            data = response.json()
+            if data.get('items'):
+                first_result = data['items'][0]
+                GEO_dict_rev['GEO_method'] = 'HERE_Geocode_reverse'
+                GEO_dict_rev['GEO_formatted_full_string'] = first_result.get('title', '')
+                GEO_dict_rev['GEO_decimal_lat'] = first_result['position']['lat']
+                GEO_dict_rev['GEO_decimal_long'] = first_result['position']['lng']
+
+                address = first_result.get('address', {})
+                GEO_dict_rev['GEO_city'] = address.get('city', '')
+                GEO_dict_rev['GEO_county'] = address.get('county', '')
+                GEO_dict_rev['GEO_state'] = address.get('state', '')
+                GEO_dict_rev['GEO_state_code'] = address.get('stateCode', '')
+                GEO_dict_rev['GEO_country'] = address.get('countryName', '')
+                GEO_dict_rev['GEO_country_code'] = address.get('countryCode', '')
+                GEO_dict_rev['GEO_continent'] = get_continent(address.get('countryName', ''))
+
+        ### REVERSE Verbatim
+        # If there are two string in the coordinates, try a reverse first based on the literal coordinates
+        if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
+            GEO_dict = GEO_dict_rev
+        else:
+            response = requests.get(reverse_url, params=params_reverse_verbatim)
+            if response.status_code == 200:
+                data = response.json()
+                if data.get('items'):
+                    first_result = data['items'][0]
+                    GEO_dict_rev_verbatim['GEO_method'] = 'HERE_Geocode_reverse_verbatimCoordinates'
+                    GEO_dict_rev_verbatim['GEO_formatted_full_string'] = first_result.get('title', '')
+                    GEO_dict_rev_verbatim['GEO_decimal_lat'] = first_result['position']['lat']
+                    GEO_dict_rev_verbatim['GEO_decimal_long'] = first_result['position']['lng']
+
+                    address = first_result.get('address', {})
+                    GEO_dict_rev_verbatim['GEO_city'] = address.get('city', '')
+                    GEO_dict_rev_verbatim['GEO_county'] = address.get('county', '')
+                    GEO_dict_rev_verbatim['GEO_state'] = address.get('state', '')
+                    GEO_dict_rev_verbatim['GEO_state_code'] = address.get('stateCode', '')
+                    GEO_dict_rev_verbatim['GEO_country'] = address.get('countryName', '')
+                    GEO_dict_rev_verbatim['GEO_country_code'] = address.get('countryCode', '')
+                    GEO_dict_rev_verbatim['GEO_continent'] = get_continent(address.get('countryName', ''))
+
+        ### FORWARD
+        ### Try forward, if failes, try reverse using deci, then verbatim
+        if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
+            GEO_dict = GEO_dict_rev
+        elif GEO_dict_rev_verbatim['GEO_city']:
+            GEO_dict = GEO_dict_rev_verbatim
+        else:
+            response = requests.get(forward_url, params=params_forward)
+            if response.status_code == 200:
+                data = response.json()
+                if data.get('items'):
+                    first_result = data['items'][0]
+                    GEO_dict_forward['GEO_method'] = 'HERE_Geocode_forward'
+                    GEO_dict_forward['GEO_formatted_full_string'] = first_result.get('title', '')
+                    GEO_dict_forward['GEO_decimal_lat'] = first_result['position']['lat']
+                    GEO_dict_forward['GEO_decimal_long'] = first_result['position']['lng']
+
+                    address = first_result.get('address', {})
+                    GEO_dict_forward['GEO_city'] = address.get('city', '')
+                    GEO_dict_forward['GEO_county'] = address.get('county', '')
+                    GEO_dict_forward['GEO_state'] = address.get('state', '')
+                    GEO_dict_forward['GEO_state_code'] = address.get('stateCode', '')
+                    GEO_dict_forward['GEO_country'] = address.get('countryName', '')
+                    GEO_dict_forward['GEO_country_code'] = address.get('countryCode', '')
+                    GEO_dict_forward['GEO_continent'] = get_continent(address.get('countryName', ''))
+
+        ### FORWARD locality
+        ### Try forward, if failes, try reverse using deci, then verbatim
+        if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
+            GEO_dict = GEO_dict_rev
+        elif GEO_dict_rev_verbatim['GEO_city']:
+            GEO_dict = GEO_dict_rev_verbatim
+        elif GEO_dict_forward['GEO_city']:
+            GEO_dict = GEO_dict_forward
+        else:
+            response = requests.get(forward_url, params=params_forward_locality)
+            if response.status_code == 200:
+                data = response.json()
+                if data.get('items'):
+                    first_result = data['items'][0]
+                    GEO_dict_forward_locality['GEO_method'] = 'HERE_Geocode_forward_locality'
+                    GEO_dict_forward_locality['GEO_formatted_full_string'] = first_result.get('title', '')
+                    GEO_dict_forward_locality['GEO_decimal_lat'] = first_result['position']['lat']
+                    GEO_dict_forward_locality['GEO_decimal_long'] = first_result['position']['lng']
+
+                    address = first_result.get('address', {})
+                    GEO_dict_forward_locality['GEO_city'] = address.get('city', '')
+                    GEO_dict_forward_locality['GEO_county'] = address.get('county', '')
+                    GEO_dict_forward_locality['GEO_state'] = address.get('state', '')
+                    GEO_dict_forward_locality['GEO_state_code'] = address.get('stateCode', '')
+                    GEO_dict_forward_locality['GEO_country'] = address.get('countryName', '')
+                    GEO_dict_forward_locality['GEO_country_code'] = address.get('countryCode', '')
+                    GEO_dict_forward_locality['GEO_continent'] = get_continent(address.get('countryName', ''))
+
+
+        # print(json.dumps(GEO_dict,indent=4))
+
+
+        # Pick the most detailed version
+        # if GEO_dict_rev['GEO_formatted_full_string'] and GEO_dict_forward['GEO_formatted_full_string']:
+        for loc in pinpoint:
+            rev = GEO_dict_rev.get(loc,'')
+            forward = GEO_dict_forward.get(loc,'')
+            forward_locality = GEO_dict_forward_locality.get(loc,'')
+            rev_verbatim = GEO_dict_rev_verbatim.get(loc,'')
+
+            if not rev and not forward and not forward_locality and not rev_verbatim:
+                pass
+            elif rev:
+                GEO_dict = GEO_dict_rev
+                break
+            elif forward:
+                GEO_dict = GEO_dict_forward
+                break
+            elif forward_locality:
+                GEO_dict = GEO_dict_forward_locality
+                break
+            elif rev_verbatim:
+                GEO_dict = GEO_dict_rev_verbatim
+                break
+            else:
+                GEO_dict = GEO_dict_null
+
+
+    if GEO_dict['GEO_formatted_full_string'] and replace_if_success_geo:
+        GEO_dict['GEO_override_OCR'] = True
+        record['country'] = GEO_dict.get('GEO_country')
+        record['stateProvince'] = GEO_dict.get('GEO_state')
+        record['county'] = GEO_dict.get('GEO_county')
+        record['municipality'] = GEO_dict.get('GEO_city')
+
+    # print(json.dumps(GEO_dict,indent=4))
+    return record, GEO_dict
+
+
+if __name__ == "__main__":
+    validate_coordinates_here(None)
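
A hedged usage sketch for the new geolocation tool. The record values below are invented for illustration, and a valid HERE_API_KEY must be present in the environment before the call; with empty coordinates the function falls through to the forward (place-name) lookup:

import os
from vouchervision.tool_geolocate_HERE import validate_coordinates_here

os.environ.setdefault('HERE_API_KEY', 'YOUR_HERE_API_KEY')  # placeholder key
record = {
    'municipality': 'Ann Arbor', 'county': 'Washtenaw',
    'stateProvince': 'Michigan', 'country': 'USA',
    'locality': '', 'decimalLatitude': '', 'decimalLongitude': '',
    'verbatimCoordinates': '',
}
record, geo = validate_coordinates_here(True, record, replace_if_success_geo=False)
print(geo['GEO_method'], geo['GEO_decimal_lat'], geo['GEO_decimal_long'])
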
vouchervision/tool_taxonomy_WFO.py
ADDED
@@ -0,0 +1,324 @@
+import requests
+from urllib.parse import urlencode
+from Levenshtein import ratio
+from fuzzywuzzy import fuzz
+
+class WFONameMatcher:
+    def __init__(self, tool_WFO):
+        self.base_url = "https://list.worldfloraonline.org/matching_rest.php?"
+        self.N_BEST_CANDIDATES = 10
+        self.NULL_DICT = {
+            "WFO_exact_match": False,
+            "WFO_exact_match_name": "",
+            "WFO_candidate_names": "",
+            "WFO_best_match": "",
+            "WFO_placement": "",
+            "WFO_override_OCR": False,
+        }
+        self.SEP = '|'
+        self.is_enabled = tool_WFO
+
+    def extract_input_string(self, record):
+        primary_input = f"{record.get('scientificName', '').strip()} {record.get('scientificNameAuthorship', '').strip()}".strip()
+        secondary_input = ' '.join(filter(None, [record.get('genus', '').strip(),
+                                                 record.get('subgenus', '').strip(),
+                                                 record.get('specificEpithet', '').strip(),
+                                                 record.get('infraspecificEpithet', '').strip()])).strip()
+
+        return primary_input, secondary_input
+
+    def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
+        params = {
+            "input_string": input_string,
+            "check_homonyms": check_homonyms,
+            "check_rank": check_rank,
+            "method": "full",
+            "accept_single_candidate": accept_single_candidate,
+        }
+
+        full_url = self.base_url + urlencode(params)
+
+        response = requests.get(full_url)
+        if response.status_code == 200:
+            return response.json()
+        else:
+            return {"error": True, "message": "Failed to fetch data from WFO API"}
+
+    def query_and_process(self, record):
+        primary_input, secondary_input = self.extract_input_string(record)
+
+        # Query with primary input
+        primary_result = self.query_wfo_name_matching(primary_input)
+        primary_processed, primary_ranked_candidates = self.process_wfo_response(primary_result, primary_input)
+
+        if primary_processed.get('WFO_exact_match'):
+            print("Selected Primary --- Exact Primary & Unchecked Secondary")
+            return primary_processed
+        else:
+            # Query with secondary input
+            secondary_result = self.query_wfo_name_matching(secondary_input)
+            secondary_processed, secondary_ranked_candidates = self.process_wfo_response(secondary_result, secondary_input)
+
+            if secondary_processed.get('WFO_exact_match'):
+                print("Selected Secondary --- Unchecked Primary & Exact Secondary")
+                return secondary_processed
+
+            else:
+                # Both failed, just return the first failure
+                if (primary_processed.get("WFO_candidate_names") == '') and (secondary_processed.get("WFO_candidate_names") == ''):
+                    print("Selected Primary --- Failed Primary & Failed Secondary")
+                    return primary_processed
+
+                # 1st failed, just return the second
+                elif (primary_processed.get("WFO_candidate_names") == '') and (len(secondary_processed.get("WFO_candidate_names")) > 0):
+                    print("Selected Secondary --- Failed Primary & Partial Secondary")
+                    return secondary_processed
+
+                # 2nd failed, just return the first
+                elif (len(primary_processed.get("WFO_candidate_names")) > 0) and (secondary_processed.get("WFO_candidate_names") == ''):
+                    print("Selected Primary --- Partial Primary & Failed Secondary")
+                    return primary_processed
+
+                # Both have partial matches, compare and rerank
+                elif (len(primary_processed.get("WFO_candidate_names")) > 0) and (len(secondary_processed.get("WFO_candidate_names")) > 0):
+                    # Combine and sort results, ensuring no duplicates
+                    combined_candidates = list(set(primary_ranked_candidates + secondary_ranked_candidates))
+                    combined_candidates.sort(key=lambda x: (x[1], x[0]), reverse=True) # Sort by similarity score, then name
+
+                    # Replace candidates with combined_candidates and combined best match
+                    best_score_primary = primary_processed["WFO_candidate_names"][0][1]
+                    best_score_secondary = secondary_processed["WFO_candidate_names"][0][1]
+
+                    # Extracting only the candidate names from the top candidates
+                    top_candidates = combined_candidates[:self.N_BEST_CANDIDATES]
+                    cleaned_candidates = [cand[0] for cand in top_candidates]
+
+                    if best_score_primary >= best_score_secondary:
+
+                        primary_processed["WFO_candidate_names"] = cleaned_candidates
+                        primary_processed["WFO_best_match"] = cleaned_candidates[0]
+
+                        response_placement = self.query_wfo_name_matching(primary_processed["WFO_best_match"])
+                        placement_exact_match = response_placement.get("match")
+                        primary_processed["WFO_placement"] = placement_exact_match.get("placement", '')
+
+                        print("Selected Primary --- Partial Primary & Partial Secondary")
+                        return primary_processed
+                    else:
+                        secondary_processed["WFO_candidate_names"] = cleaned_candidates
+                        secondary_processed["WFO_best_match"] = cleaned_candidates[0]
+
+                        response_placement = self.query_wfo_name_matching(secondary_processed["WFO_best_match"])
+                        placement_exact_match = response_placement.get("match")
+                        secondary_processed["WFO_placement"] = placement_exact_match.get("placement", '')
+
+                        print("Selected Secondary --- Partial Primary & Partial Secondary")
+                        return secondary_processed
+                else:
+                    return self.NULL_DICT
+
+    def process_wfo_response(self, response, query):
+        simplified_response = {}
+        ranked_candidates = None
+
+        exact_match = response.get("match")
+        simplified_response["WFO_exact_match"] = bool(exact_match)
+
+        candidates = response.get("candidates", [])
+        candidate_names = [candidate["full_name_plain"] for candidate in candidates] if candidates else []
+
+        if not exact_match and candidate_names:
+            cleaned_candidates, ranked_candidates = self._rank_candidates_by_similarity(query, candidate_names)
+            simplified_response["WFO_candidate_names"] = cleaned_candidates
+            simplified_response["WFO_best_match"] = cleaned_candidates[0] if cleaned_candidates else ''
+        elif exact_match:
+            simplified_response["WFO_candidate_names"] = exact_match.get("full_name_plain")
+            simplified_response["WFO_best_match"] = exact_match.get("full_name_plain")
+        else:
+            simplified_response["WFO_candidate_names"] = ''
+            simplified_response["WFO_best_match"] = ''
+
+        # Call WFO again to update placement using WFO_best_match
+        try:
+            response_placement = self.query_wfo_name_matching(simplified_response["WFO_best_match"])
+            placement_exact_match = response_placement.get("match")
+            simplified_response["WFO_placement"] = placement_exact_match.get("placement", '')
+        except:
+            simplified_response["WFO_placement"] = ''
+
+        return simplified_response, ranked_candidates
+
+    def _rank_candidates_by_similarity(self, query, candidates):
+        string_similarities = []
+        fuzzy_similarities = {candidate: fuzz.ratio(query, candidate) for candidate in candidates}
+        query_words = query.split()
+
+        for candidate in candidates:
+            candidate_words = candidate.split()
+            # Calculate word similarities and sum them up
+            word_similarities = [ratio(query_word, candidate_word) for query_word, candidate_word in zip(query_words, candidate_words)]
+            total_word_similarity = sum(word_similarities)
+
+            # Calculate combined similarity score (average of word and fuzzy similarities)
+            fuzzy_similarity = fuzzy_similarities[candidate]
+            combined_similarity = (total_word_similarity + fuzzy_similarity) / 2
+            string_similarities.append((candidate, combined_similarity))
+
+        # Sort the candidates based on combined similarity, higher scores first
+        ranked_candidates = sorted(string_similarities, key=lambda x: x[1], reverse=True)
+
+        # Extracting only the candidate names from the top candidates
+        top_candidates = ranked_candidates[:self.N_BEST_CANDIDATES]
+        cleaned_candidates = [cand[0] for cand in top_candidates]
+
+        return cleaned_candidates, ranked_candidates
+
+    def check_WFO(self, record, replace_if_success_wfo):
+        if not self.is_enabled:
+            return record, self.NULL_DICT
+
+        else:
+            self.replace_if_success_wfo = replace_if_success_wfo
+
+            # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
+            simplified_response = self.query_and_process(record)
+            simplified_response['WFO_override_OCR'] = False
+
+            # best_match
+            if simplified_response.get('WFO_exact_match'):
+                simplified_response['WFO_exact_match_name'] = simplified_response.get('WFO_best_match')
+            else:
+                simplified_response['WFO_exact_match_name'] = ''
+
+            # placement
+            wfo_placement = simplified_response.get('WFO_placement', '')
+            if wfo_placement:
+                parts = wfo_placement.split('/')[1:]
+                simplified_response['WFO_placement'] = self.SEP.join(parts)
+            else:
+                simplified_response['WFO_placement'] = ''
+
+            if simplified_response.get('WFO_exact_match') and replace_if_success_wfo:
+                simplified_response['WFO_override_OCR'] = True
+                name_parts = simplified_response.get('WFO_placement').split('$')[0]
+                name_parts = name_parts.split(self.SEP)
+                record['order'] = name_parts[3]
+                record['family'] = name_parts[4]
+                record['genus'] = name_parts[5]
+                record['specificEpithet'] = name_parts[6]
+                record['scientificName'] = simplified_response.get('WFO_exact_match_name')
+
+            return record, simplified_response
+
+def validate_taxonomy_WFO(tool_WFO, record_dict, replace_if_success_wfo=False):
+    Matcher = WFONameMatcher(tool_WFO)
+    try:
+        record_dict, WFO_dict = Matcher.check_WFO(record_dict, replace_if_success_wfo)
+        return record_dict, WFO_dict
+    except:
+        return record_dict, Matcher.NULL_DICT
+
+'''
+if __name__ == "__main__":
+    Matcher = WFONameMatcher()
+    # input_string = "Rhopalocarpus alterfolius"
+    record_exact_match ={
+        "order": "Malpighiales",
+        "family": "Hypericaceae",
+        "scientificName": "Hypericum prolificum",
+        "scientificNameAuthorship": "",
+
+        "genus": "Hypericum",
+        "subgenus": "",
+        "specificEpithet": "prolificum",
+        "infraspecificEpithet": "",
+    }
+    record_partialPrimary_exactSecondary ={
+        "order": "Malpighiales",
+        "family": "Hypericaceae",
+        "scientificName": "Hyperic prolificum",
+        "scientificNameAuthorship": "",
+
+        "genus": "Hypericum",
+        "subgenus": "",
+        "specificEpithet": "prolificum",
+        "infraspecificEpithet": "",
+    }
+    record_exactPrimary_partialSecondary ={
+        "order": "Malpighiales",
+        "family": "Hypericaceae",
+        "scientificName": "Hypericum prolificum",
+        "scientificNameAuthorship": "",
+
+        "genus": "Hyperic",
+        "subgenus": "",
+        "specificEpithet": "prolificum",
+        "infraspecificEpithet": "",
+    }
+    record_partialPrimary_partialSecondary ={
"order": "Malpighiales",
|
260 |
+
"family": "Hypericaceae",
|
261 |
+
"scientificName": "Hyperic prolificum",
|
262 |
+
"scientificNameAuthorship": "",
|
263 |
+
|
264 |
+
"genus": "Hypericum",
|
265 |
+
"subgenus": "",
|
266 |
+
"specificEpithet": "prolific",
|
267 |
+
"infraspecificEpithet": "",
|
268 |
+
}
|
269 |
+
record_partialPrimary_partialSecondary_swap ={
|
270 |
+
"order": "Malpighiales",
|
271 |
+
"family": "Hypericaceae",
|
272 |
+
"scientificName": "Hypericum prolific",
|
273 |
+
"scientificNameAuthorship": "",
|
274 |
+
|
275 |
+
"genus": "Hyperic",
|
276 |
+
"subgenus": "",
|
277 |
+
"specificEpithet": "prolificum",
|
278 |
+
"infraspecificEpithet": "",
|
279 |
+
}
|
280 |
+
record_errorPrimary_partialSecondary ={
|
281 |
+
"order": "Malpighiales",
|
282 |
+
"family": "Hypericaceae",
|
283 |
+
"scientificName": "ricum proli",
|
284 |
+
"scientificNameAuthorship": "",
|
285 |
+
|
286 |
+
"genus": "Hyperic",
|
287 |
+
"subgenus": "",
|
288 |
+
"specificEpithet": "prolificum",
|
289 |
+
"infraspecificEpithet": "",
|
290 |
+
}
|
291 |
+
record_partialPrimary_errorSecondary ={
|
292 |
+
"order": "Malpighiales",
|
293 |
+
"family": "Hypericaceae",
|
294 |
+
"scientificName": "Hyperic prolificum",
|
295 |
+
"scientificNameAuthorship": "",
|
296 |
+
|
297 |
+
"genus": "ricum",
|
298 |
+
"subgenus": "",
|
299 |
+
"specificEpithet": "proli",
|
300 |
+
"infraspecificEpithet": "",
|
301 |
+
}
|
302 |
+
record_errorPrimary_errorSecondary ={
|
303 |
+
"order": "Malpighiales",
|
304 |
+
"family": "Hypericaceae",
|
305 |
+
"scientificName": "ricum proli",
|
306 |
+
"scientificNameAuthorship": "",
|
307 |
+
|
308 |
+
"genus": "ricum",
|
309 |
+
"subgenus": "",
|
310 |
+
"specificEpithet": "proli",
|
311 |
+
"infraspecificEpithet": "",
|
312 |
+
}
|
313 |
+
options = [record_exact_match,
|
314 |
+
record_partialPrimary_exactSecondary,
|
315 |
+
record_exactPrimary_partialSecondary,
|
316 |
+
record_partialPrimary_partialSecondary,
|
317 |
+
record_partialPrimary_partialSecondary_swap,
|
318 |
+
record_errorPrimary_partialSecondary,
|
319 |
+
record_partialPrimary_errorSecondary,
|
320 |
+
record_errorPrimary_errorSecondary]
|
321 |
+
for opt in options:
|
322 |
+
simplified_response = Matcher.check_WFO(opt)
|
323 |
+
print(json.dumps(simplified_response, indent=4))
|
324 |
+
'''
|
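A minimal usage sketch (not part of the commit) for the new validate_taxonomy_WFO entry point. It assumes the vouchervision package is importable and the WFO name-matching service is reachable; the record fields mirror the commented-out test block above.

from vouchervision.tool_taxonomy_WFO import validate_taxonomy_WFO

# A partially garbled OCR transcription of "Hypericum prolificum"
record = {
    "order": "Malpighiales",
    "family": "Hypericaceae",
    "scientificName": "Hyperic prolificum",
    "scientificNameAuthorship": "",
    "genus": "Hypericum",
    "subgenus": "",
    "specificEpithet": "prolificum",
    "infraspecificEpithet": "",
}

# tool_WFO=True enables the matcher; replace_if_success_wfo=True lets an exact WFO hit
# overwrite the OCR-derived order/family/genus/specificEpithet/scientificName fields.
record, wfo = validate_taxonomy_WFO(True, record, replace_if_success_wfo=True)
print(wfo["WFO_exact_match"], wfo["WFO_best_match"], wfo["WFO_placement"])

Because validate_taxonomy_WFO swallows exceptions and falls back to Matcher.NULL_DICT, a network failure degrades to empty WFO columns instead of crashing the batch.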
vouchervision/tool_wikipedia.py
CHANGED
@@ -8,7 +8,8 @@ import pstats
 class WikipediaLinks():
 
 
-    def __init__(self, json_file_path_wiki) -> None:
+    def __init__(self, tool_wikipedia, json_file_path_wiki) -> None:
+        self.is_enabled = tool_wikipedia
         self.json_file_path_wiki = json_file_path_wiki
         self.wiki_wiki = wikipediaapi.Wikipedia(
             user_agent='VoucherVision (merlin@example.com)',
@@ -466,54 +467,56 @@
         self.info_packet['WIKI_GEO'] = {}
         self.info_packet['WIKI_LOCALITY'] = {}
 
-        municipality = output.get('municipality','')
-        county = output.get('county','')
-        stateProvince = output.get('stateProvince','')
-        country = output.get('country','')
-
-        locality = output.get('locality','')
-
-        order = output.get('order','')
-        family = output.get('family','')
-        scientificName = output.get('scientificName','')
-        genus = output.get('genus','')
-        specificEpithet = output.get('specificEpithet','')
-
-        query_geo = ' '.join([municipality, county, stateProvince, country]).strip()
-        query_locality = locality.strip()
-        query_taxa_primary = scientificName.strip()
-        query_taxa_secondary = ' '.join([genus, specificEpithet]).strip()
-        query_taxa_tertiary = ' '.join([order, family, genus, specificEpithet]).strip()
-
-        try:
-            self.gather_geo(query_geo)
-        except:
-            pass
-
-        try:
-            self.gather_geo(query_locality,'locality')
-        except:
-            pass
-
-        queries_taxa = [query_taxa_primary, query_taxa_secondary, query_taxa_tertiary]
-        for q in queries_taxa:
-            if q:
-                try:
-                    self.gather_taxonomy(q)
-                    break
-                except:
-                    pass
+        if self.is_enabled:
+
+            municipality = output.get('municipality','')
+            county = output.get('county','')
+            stateProvince = output.get('stateProvince','')
+            country = output.get('country','')
+
+            locality = output.get('locality','')
+
+            order = output.get('order','')
+            family = output.get('family','')
+            scientificName = output.get('scientificName','')
+            genus = output.get('genus','')
+            specificEpithet = output.get('specificEpithet','')
+
+            query_geo = ' '.join([municipality, county, stateProvince, country]).strip()
+            query_locality = locality.strip()
+            query_taxa_primary = scientificName.strip()
+            query_taxa_secondary = ' '.join([genus, specificEpithet]).strip()
+            query_taxa_tertiary = ' '.join([order, family, genus, specificEpithet]).strip()
+
+            # query_taxa = "Tracaulon sagittatum Tracaulon sagittatum"
+            # query_geo = "Indiana Porter Co."
+            # query_locality = "Mical Springs edge"
+
+            if query_geo:
+                try:
+                    self.gather_geo(query_geo)
+                except:
+                    pass
+
+            if query_locality:
+                try:
+                    self.gather_geo(query_locality,'locality')
+                except:
+                    pass
+
+            queries_taxa = [query_taxa_primary, query_taxa_secondary, query_taxa_tertiary]
+            for q in queries_taxa:
+                if q:
+                    try:
+                        self.gather_taxonomy(q)
+                        break
+                    except:
+                        pass
+
+            # print(self.info_packet)
+            # return self.info_packet
+            # self.gather_geo(query_geo)
         try:
             with open(self.json_file_path_wiki, 'w', encoding='utf-8') as file:
                 json.dump(self.info_packet, file, indent=4)
@@ -547,6 +550,13 @@
         return clean_text
 
 
+
+def validate_wikipedia(tool_wikipedia, json_file_path_wiki, output):
+    Wiki = WikipediaLinks(tool_wikipedia, json_file_path_wiki)
+    Wiki.gather_wikipedia_results(output)
+
+
+
 if __name__ == '__main__':
     test_output = {
         "filename": "MICH_7375774_Polygonaceae_Persicaria_",
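A sketch of the new module-level helper (assumes the wikipedia-api dependency is installed and network access is available; the record values are illustrative, echoing the commented-out queries and the MICH test filename above).

from vouchervision.tool_wikipedia import validate_wikipedia

output = {
    "municipality": "", "county": "Porter Co.", "stateProvince": "Indiana", "country": "USA",
    "locality": "Mical Springs edge",
    "order": "", "family": "Polygonaceae",
    "scientificName": "Persicaria sagittata", "genus": "Persicaria", "specificEpithet": "sagittata",
}

# Writes the WIKI_* info packet to wiki_links.json; with tool_wikipedia=False the
# lookups are skipped but the (empty) packet is still written, since the file write
# sits outside the new `if self.is_enabled:` guard.
validate_wikipedia(True, "wiki_links.json", output)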
vouchervision/utils_LLM.py
CHANGED
@@ -8,6 +8,60 @@ import psutil
 import threading
 import torch
 from datetime import datetime
+from vouchervision.tool_taxonomy_WFO import validate_taxonomy_WFO, WFONameMatcher
+from vouchervision.tool_geolocate_HERE import validate_coordinates_here
+from vouchervision.tool_wikipedia import validate_wikipedia
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+def run_tools(output, tool_WFO, tool_GEO, tool_wikipedia, json_file_path_wiki):
+    # Define a function that will catch and return the results of your functions
+    def task(func, *args, **kwargs):
+        return func(*args, **kwargs)
+
+    # List of tasks to run in separate threads
+    tasks = [
+        (validate_taxonomy_WFO, (tool_WFO, output, False)),
+        (validate_coordinates_here, (tool_GEO, output, False)),
+        (validate_wikipedia, (tool_wikipedia, json_file_path_wiki, output)),
+    ]
+
+    # Results storage
+    results = {}
+
+    # Use ThreadPoolExecutor to execute each function in its own thread
+    with ThreadPoolExecutor() as executor:
+        future_to_func = {executor.submit(task, func, *args): func.__name__ for func, args in tasks}
+        for future in as_completed(future_to_func):
+            func_name = future_to_func[future]
+            try:
+                # Collecting results
+                results[func_name] = future.result()
+            except Exception as exc:
+                print(f'{func_name} generated an exception: {exc}')
+
+    # Here, all threads have completed
+    # Extracting results
+    Matcher = WFONameMatcher(tool_WFO)
+    GEO_dict_null = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    output_WFO, WFO_record = results.get('validate_taxonomy_WFO', (output, Matcher.NULL_DICT))
+    output_GEO, GEO_record = results.get('validate_coordinates_here', (output, GEO_dict_null))
+
+    return output_WFO, WFO_record, output_GEO, GEO_record
+
 
 def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
     with open(txt_file_path_ind_prompt, 'w',encoding='utf-8') as file:
@@ -19,6 +73,16 @@ def remove_colons_and_double_apostrophes(text):
     return text.replace(":", "").replace("\"", "")
 
 
+def sanitize_prompt(data):
+    if isinstance(data, dict):
+        return {sanitize_prompt(key): sanitize_prompt(value) for key, value in data.items()}
+    elif isinstance(data, list):
+        return [sanitize_prompt(element) for element in data]
+    elif isinstance(data, str):
+        return data.encode('utf-8', 'ignore').decode('utf-8')
+    else:
+        return data
+
 
 def count_tokens(string, vendor, model_name):
     full_string = string + JSON_FORMAT_INSTRUCTIONS
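A sketch of how run_tools is meant to be driven (function and argument names come from the diff; the record content is illustrative, and the call assumes the relevant WFO/HERE/Wikipedia services and keys are available). Note that validate_wikipedia returns None and only writes its JSON side effect, so only the WFO and GEO entries are read back out of results.

from vouchervision.utils_LLM import run_tools, sanitize_prompt

# sanitize_prompt strips non-UTF-8 bytes from every key and value in the record
output = sanitize_prompt({
    "scientificName": "Hypericum prolificum",
    "country": "USA", "stateProvince": "Indiana", "county": "Porter Co.",
})

output_WFO, WFO_record, output_GEO, GEO_record = run_tools(
    output, tool_WFO=True, tool_GEO=True, tool_wikipedia=True,
    json_file_path_wiki="wiki_links.json")

print(WFO_record.get("WFO_best_match"), GEO_record.get("GEO_method"))

The three validators run concurrently in the thread pool, so a slow geolocation or Wikipedia call no longer serializes behind the WFO lookup.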
vouchervision/utils_VoucherVision.py
CHANGED
@@ -43,7 +43,7 @@ class VoucherVision():
         self.prompt_version = None
         self.is_hf = is_hf
 
-        self.trOCR_model_version = "microsoft/trocr-large-handwritten"
+        # self.trOCR_model_version = "microsoft/trocr-large-handwritten"
         # self.trOCR_model_version = "microsoft/trocr-base-handwritten"
         # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
        # self.trOCR_model_version = "dh-unibe/trocr-kurrent" # NOPE
@@ -59,6 +59,8 @@
         self.logger.name = f'[Transcription]'
         self.logger.info(f'Setting up OCR and LLM')
 
+        self.trOCR_model_version = self.cfg['leafmachine']['project']['trOCR_model_path']
+
         self.db_name = self.cfg['leafmachine']['project']['embeddings_database_name']
         self.path_domain_knowledge = self.cfg['leafmachine']['project']['path_to_domain_knowledge_xlsx']
         self.build_new_db = self.cfg['leafmachine']['project']['build_new_embeddings_database']
@@ -83,7 +85,7 @@
         self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
         self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
 
-        self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
+        self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "LM2_collage", "OCR_method", "OCR_double", "OCR_trOCR", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
         # "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
 
         # "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
@@ -298,7 +300,8 @@
                 break
 
 
-    def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
+    def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report,
+                                        MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
 
 
         wb = openpyxl.load_workbook(path_transcription)
@@ -367,7 +370,17 @@
                     sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
                 elif header.value == "run_name":
                     sheet.cell(row=next_row, column=i, value=Dirs.run_name)
-
+                elif header.value == "LM2_collage":
+                    sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['use_RGB_label_images'])
+                elif header.value == "OCR_method":
+                    value_to_insert = self.cfg['leafmachine']['project']['OCR_option']
+                    if isinstance(value_to_insert, list):
+                        value_to_insert = '|'.join(map(str, value_to_insert))
+                    sheet.cell(row=next_row, column=i, value=value_to_insert)
+                elif header.value == "OCR_double":
+                    sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['project']['double_OCR'])
+                elif header.value == "OCR_trOCR":
+                    sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['project']['do_use_trOCR'])
                 # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
                 elif header.value in self.wfo_headers_no_lists:
                     sheet.cell(row=next_row, column=i, value=WFO_record.get(header.value, ''))
@@ -404,10 +417,11 @@
 
 
     def has_API_key(self, val):
-        if val != '':
-            return True
-        else:
-            return False
+        return isinstance(val, str) and bool(val.strip())
+        # if val != '':
+        #     return True
+        # else:
+        #     return False
 
 
     def get_google_credentials(self): # Also used for google drive
@@ -460,6 +474,7 @@
 
         self.has_key_openai = self.has_API_key(k_openai)
         self.has_key_azure_openai = self.has_API_key(k_openai_azure)
+        self.llm = None
 
         self.has_key_google_project_id = self.has_API_key(k_google_project_id)
         self.has_key_google_location = self.has_API_key(k_google_location)
@@ -470,12 +485,15 @@
         self.has_key_open_cage_geocode = self.has_API_key(k_opencage)
 
 
+
         ### Google - OCR, Palm2, Gemini
         if self.has_key_google_application_credentials and self.has_key_google_project_id and self.has_key_google_location:
             if self.is_hf:
                 vertexai.init(project=os.getenv('GOOGLE_PROJECT_ID'), location=os.getenv('GOOGLE_LOCATION'), credentials=self.get_google_credentials())
             else:
                 vertexai.init(project=k_google_project_id, location=k_google_location, credentials=self.get_google_credentials())
+                os.environ['GOOGLE_API_KEY'] = self.cfg_private['google']['GOOGLE_PALM_API']
+
 
         ### OpenAI
         if self.has_key_openai:
@@ -497,7 +515,6 @@
                     azure_endpoint = os.getenv('AZURE_API_BASE'),
                     openai_organization = os.getenv('AZURE_ORGANIZATION'),
                 )
-                self.has_key_azure_openai = True
 
             else:
                 # Initialize the Azure OpenAI client
@@ -508,7 +525,6 @@
                     azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
                     openai_organization = self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
                 )
-                self.has_key_azure_openai = True
 
 
         ### Mistral
@@ -624,6 +640,7 @@
         ocr_google = OCREngine(self.logger, json_report, self.dir_home, self.is_hf, self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
         ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
         self.OCR = ocr_google.OCR
+        self.logger.info(f"Complete OCR text for LLM prompt:\n\n{self.OCR}\n\n")
 
         self.write_json_to_file(txt_file_path_OCR, ocr_google.OCR_JSON_to_file)
 
@@ -671,7 +688,7 @@
 
            json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
            json_report.set_JSON({}, {}, {})
-            llm_model = self.initialize_llm_model(self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)
+            llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)
 
             for i, path_to_crop in enumerate(self.img_paths):
                 self.update_progress_report_batch(progress_report, i)
@@ -729,7 +746,7 @@
 
                 final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
 
-                self.log_completion_info(final_JSON_response)
+                self.logger.info(f'Finished LLM call')
 
                 json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)
 
@@ -741,22 +758,22 @@
     ##################################################################################################################################
     ################################################## LLM Helper Funcs ##############################################################
     ##################################################################################################################################
-    def initialize_llm_model(self, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
+    def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
         if 'LOCAL'in name_parts:
             if ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts):
                 if 'CPU' in name_parts:
-                    return LocalCPUMistralHandler(logger, model_name, JSON_dict_structure)
+                    return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure)
                 else:
-                    return LocalMistralHandler(logger, model_name, JSON_dict_structure)
+                    return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure)
         else:
             if 'PALM2' in name_parts:
-                return GooglePalm2Handler(logger, model_name, JSON_dict_structure)
+                return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure)
             elif 'GEMINI' in name_parts:
-                return GoogleGeminiHandler(logger, model_name, JSON_dict_structure)
+                return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure)
            elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
-                return MistralHandler(logger, model_name, JSON_dict_structure)
+                return MistralHandler(cfg, logger, model_name, JSON_dict_structure)
             else:
-                return OpenAIHandler(logger, model_name, JSON_dict_structure, is_azure, llm_object)
+                return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object)
 
     def setup_prompt(self):
         Catalog = PromptCatalog()
@@ -807,11 +824,6 @@
         return final_JSON_response_updated, WFO_record, GEO_record
 
 
-    def log_completion_info(self, final_JSON_response):
-        self.logger.info(f'Formatted JSON\n{final_JSON_response}')
-        self.logger.info(f'Finished API calls\n')
-
-
     def update_progress_report_final(self, progress_report):
         if progress_report is not None:
             progress_report.reset_batch("Batch Complete")
@@ -839,7 +851,8 @@
         return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
 
 
-    def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
+    def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report,
+                           MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
         if response is None:
             response = self.JSON_dict_structure
         # Insert 'filename' as the first key
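
Two of the smaller behavior changes above, restated as a self-contained sketch (the OCR_option values are placeholders, not the real config options):

def has_API_key(val):
    # New implementation from the hunk above: only a non-empty,
    # non-whitespace string counts as a usable key.
    return isinstance(val, str) and bool(val.strip())

assert has_API_key("sk-abc123")
assert not has_API_key("")
assert not has_API_key("   ")   # whitespace-only keys are now rejected
assert not has_API_key(None)    # non-string values no longer slip through

# The new OCR_method column flattens a list-valued OCR_option with '|'
# before writing it into a single spreadsheet cell.
ocr_option = ["OCR_A", "OCR_B"]  # placeholder for cfg['leafmachine']['project']['OCR_option']
value_to_insert = '|'.join(map(str, ocr_option)) if isinstance(ocr_option, list) else ocr_option
assert value_to_insert == "OCR_A|OCR_B"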
|