phyloforfun commited on
Commit
c5e57d6
·
1 Parent(s): 712822d

July 18 update

Browse files
app.py CHANGED
@@ -144,6 +144,8 @@ if 'present_annotations' not in st.session_state:
144
  st.session_state['present_annotations'] = None
145
  if 'missing_annotations' not in st.session_state:
146
  st.session_state['missing_annotations'] = None
 
 
147
  if 'date_of_check' not in st.session_state:
148
  st.session_state['date_of_check'] = None
149
 
@@ -1016,6 +1018,16 @@ def create_private_file():
1016
  st.write("Leave keys blank if you do not intend to use that service.")
1017
  st.info("Note: You can manually edit these API keys later by opening the /PRIVATE_DATA.yaml file in a plain text editor.")
1018
 
 
 
 
 
 
 
 
 
 
 
1019
  st.write("---")
1020
  st.subheader("Google Vision (*Required*) / Google PaLM 2 / Google Gemini")
1021
  st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")
@@ -1170,6 +1182,7 @@ def create_private_file():
1170
  st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys,
1171
  args=[cfg_private,
1172
  openai_api_key,
 
1173
  azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
1174
  google_application_credentials, google_project_location, google_project_id,
1175
  mistral_API_KEY,
@@ -1183,12 +1196,15 @@ def create_private_file():
1183
 
1184
  def save_changes_to_API_keys(cfg_private,
1185
  openai_api_key,
 
1186
  azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
1187
  google_application_credentials, google_project_location, google_project_id,
1188
  mistral_API_KEY,
1189
  here_APP_ID, here_API_KEY):
1190
 
1191
  # Update the configuration dictionary with the new values
 
 
1192
  cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key
1193
 
1194
  cfg_private['openai_azure']['OPENAI_API_VERSION'] = azure_openai_api_version
@@ -1269,8 +1285,19 @@ def display_api_key_status(ccol):
1269
  # Convert keys to annotations (similar to what you do in check_api_key_status)
1270
  present_annotations = []
1271
  missing_annotations = []
 
1272
  for key in present_keys:
1273
- if "Valid" in key:
 
 
 
 
 
 
 
 
 
 
1274
  show_text = key.split('(')[0]
1275
  present_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
1276
  elif "Invalid" in key:
@@ -1279,6 +1306,7 @@ def display_api_key_status(ccol):
1279
 
1280
  st.session_state['present_annotations'] = present_annotations
1281
  st.session_state['missing_annotations'] = missing_annotations
 
1282
  st.session_state['date_of_check'] = date_of_check
1283
  st.session_state['API_checked'] = True
1284
  # print('for')
@@ -1307,6 +1335,14 @@ def display_api_key_status(ccol):
1307
  if 'missing_annotations' in st.session_state and st.session_state['missing_annotations']:
1308
  annotated_text(*st.session_state['missing_annotations'])
1309
 
 
 
 
 
 
 
 
 
1310
 
1311
 
1312
  def check_api_key_status():
@@ -1322,8 +1358,19 @@ def check_api_key_status():
1322
  # Prepare annotations for present keys
1323
  present_annotations = []
1324
  missing_annotations = []
 
1325
  for key in present_keys:
1326
- if "Valid" in key:
 
 
 
 
 
 
 
 
 
 
1327
  show_text = key.split('(')[0]
1328
  present_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
1329
  elif "Invalid" in key:
@@ -1340,6 +1387,7 @@ def check_api_key_status():
1340
 
1341
  st.session_state['present_annotations'] = present_annotations
1342
  st.session_state['missing_annotations'] = missing_annotations
 
1343
  st.session_state['date_of_check'] = date_of_check
1344
 
1345
 
@@ -1831,7 +1879,7 @@ def content_ocr_method():
1831
  demo_text_trh = demo_text_h + '\n' + demo_text_tr
1832
  demo_text_trp = demo_text_p + '\n' + demo_text_tr
1833
 
1834
- options = ["Google Vision Handwritten", "Google Vision Printed", "CRAFT + trOCR","LLaVA"]
1835
  options_llava = ["llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",]
1836
  options_llava_bit = ["full", "4bit",]
1837
  captions_llava = [
@@ -1882,6 +1930,7 @@ def content_ocr_method():
1882
  "Google Vision Printed": 'normal',
1883
  "CRAFT + trOCR": 'CRAFT',
1884
  "LLaVA": 'LLaVA',
 
1885
  }
1886
 
1887
  # Map selected options to their corresponding internal representations
@@ -1914,6 +1963,19 @@ def content_ocr_method():
1914
  else:
1915
  st.session_state.config['leafmachine']['project']['trOCR_model_path'] = user_input_trOCR_model_path
1916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1917
  if 'LLaVA' in selected_OCR_options:
1918
  OCR_option_llava = st.radio(
1919
  "Select the LLaVA version",
 
144
  st.session_state['present_annotations'] = None
145
  if 'missing_annotations' not in st.session_state:
146
  st.session_state['missing_annotations'] = None
147
+ if 'model_annotations' not in st.session_state:
148
+ st.session_state['model_annotations'] = None
149
  if 'date_of_check' not in st.session_state:
150
  st.session_state['date_of_check'] = None
151
 
 
1018
  st.write("Leave keys blank if you do not intend to use that service.")
1019
  st.info("Note: You can manually edit these API keys later by opening the /PRIVATE_DATA.yaml file in a plain text editor.")
1020
 
1021
+ st.write("---")
1022
+ st.subheader("Hugging Face (*Required For Local LLMs*)")
1023
+ st.markdown("VoucherVision relies on LLM models from Hugging Face. Some models are 'gated', meaning that you have to agree to the creator's usage guidelines.")
1024
+ st.markdown("""Create a [Hugging Face account](https://huggingface.co/join). Once your account is created, in your profile settings [navigate to 'Access Tokens'](https://huggingface.co/settings/tokens) and click 'Create new token'. Create a token that has 'Read' privileges. Copy the token into the field below.""")
1025
+
1026
+ hugging_face_token = st.text_input(label = 'Hugging Face token', value = cfg_private['huggingface'].get('hf_token', ''),
1027
+ placeholder = 'e.g. hf_GNRLIUBnvfkjvnf....',
1028
+ help ="This is your Hugging Face access token. It only needs Read access. Please see https://huggingface.co/settings/tokens",
1029
+ type='password')
1030
+
1031
  st.write("---")
1032
  st.subheader("Google Vision (*Required*) / Google PaLM 2 / Google Gemini")
1033
  st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")
 
1182
  st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys,
1183
  args=[cfg_private,
1184
  openai_api_key,
1185
+ hugging_face_token,
1186
  azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
1187
  google_application_credentials, google_project_location, google_project_id,
1188
  mistral_API_KEY,
 
1196
 
1197
  def save_changes_to_API_keys(cfg_private,
1198
  openai_api_key,
1199
+ hugging_face_token,
1200
  azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
1201
  google_application_credentials, google_project_location, google_project_id,
1202
  mistral_API_KEY,
1203
  here_APP_ID, here_API_KEY):
1204
 
1205
  # Update the configuration dictionary with the new values
1206
+ cfg_private['huggingface']['hf_token'] = hugging_face_token
1207
+
1208
  cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key
1209
 
1210
  cfg_private['openai_azure']['OPENAI_API_VERSION'] = azure_openai_api_version
 
1285
  # Convert keys to annotations (similar to what you do in check_api_key_status)
1286
  present_annotations = []
1287
  missing_annotations = []
1288
+ model_annotations = []
1289
  for key in present_keys:
1290
+ if "[MODEL]" in key:
1291
+ show_text = key.split(']')[1]
1292
+ show_text = show_text.split('(')[0]
1293
+ if 'Under Review' in key:
1294
+ model_annotations.append((show_text, "under review", "#9C0586")) # Green for valid
1295
+ elif 'invalid' in key:
1296
+ model_annotations.append((show_text, "error!", "#870307")) # Green for valid
1297
+ else:
1298
+ model_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
1299
+
1300
+ elif "Valid" in key:
1301
  show_text = key.split('(')[0]
1302
  present_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
1303
  elif "Invalid" in key:
 
1306
 
1307
  st.session_state['present_annotations'] = present_annotations
1308
  st.session_state['missing_annotations'] = missing_annotations
1309
+ st.session_state['model_annotations'] = model_annotations
1310
  st.session_state['date_of_check'] = date_of_check
1311
  st.session_state['API_checked'] = True
1312
  # print('for')
 
1335
  if 'missing_annotations' in st.session_state and st.session_state['missing_annotations']:
1336
  annotated_text(*st.session_state['missing_annotations'])
1337
 
1338
+ if not st.session_state['is_hf']:
1339
+ st.markdown(f"Access to Hugging Face Models")
1340
+
1341
+ if 'model_annotations' in st.session_state and st.session_state['model_annotations']:
1342
+ annotated_text(*st.session_state['model_annotations'])
1343
+
1344
+
1345
+
1346
 
1347
 
1348
  def check_api_key_status():
 
1358
  # Prepare annotations for present keys
1359
  present_annotations = []
1360
  missing_annotations = []
1361
+ model_annotations = []
1362
  for key in present_keys:
1363
+ if "[MODEL]" in key:
1364
+ show_text = key.split(']')[1]
1365
+ show_text = show_text.split('(')[0]
1366
+ if 'Under Review' in key:
1367
+ model_annotations.append((show_text, "under review", "#9C0586")) # Green for valid
1368
+ elif 'invalid' in key:
1369
+ model_annotations.append((show_text, "error!", "#870307")) # Green for valid
1370
+ else:
1371
+ model_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
1372
+
1373
+ elif "Valid" in key:
1374
  show_text = key.split('(')[0]
1375
  present_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
1376
  elif "Invalid" in key:
 
1387
 
1388
  st.session_state['present_annotations'] = present_annotations
1389
  st.session_state['missing_annotations'] = missing_annotations
1390
+ st.session_state['model_annotations'] = model_annotations
1391
  st.session_state['date_of_check'] = date_of_check
1392
 
1393
 
 
1879
  demo_text_trh = demo_text_h + '\n' + demo_text_tr
1880
  demo_text_trp = demo_text_p + '\n' + demo_text_tr
1881
 
1882
+ options = ["Google Vision Handwritten", "Google Vision Printed", "CRAFT + trOCR","LLaVA", "Florence-2"]
1883
  options_llava = ["llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",]
1884
  options_llava_bit = ["full", "4bit",]
1885
  captions_llava = [
 
1930
  "Google Vision Printed": 'normal',
1931
  "CRAFT + trOCR": 'CRAFT',
1932
  "LLaVA": 'LLaVA',
1933
+ "Florence-2": 'Florence-2',
1934
  }
1935
 
1936
  # Map selected options to their corresponding internal representations
 
1963
  else:
1964
  st.session_state.config['leafmachine']['project']['trOCR_model_path'] = user_input_trOCR_model_path
1965
 
1966
+
1967
+ if "Florence-2" in selected_OCR_options:
1968
+ default_florence_model_path = st.session_state.config['leafmachine']['project']['florence_model_path']
1969
+ user_input_florence_model_path = st.text_input("Florence-2 Hugging Face model path. MUST be a Florence-2 version based on 'microsoft/Florence-2-large' or similar.", value=default_florence_model_path)
1970
+
1971
+ if st.session_state.config['leafmachine']['project']['florence_model_path'] != user_input_florence_model_path:
1972
+ is_valid_mp = is_valid_huggingface_model_path(user_input_florence_model_path)
1973
+ if not is_valid_mp:
1974
+ st.error(f"The Hugging Face model path {user_input_florence_model_path} is not valid. Please revise.")
1975
+ else:
1976
+ st.session_state.config['leafmachine']['project']['florence_model_path'] = user_input_florence_model_path
1977
+
1978
+
1979
  if 'LLaVA' in selected_OCR_options:
1980
  OCR_option_llava = st.radio(
1981
  "Select the LLaVA version",
img/collage.jpg ADDED

Git LFS Details

  • SHA256: 39a0878254236cd6223efe677838cc2034a55f0278cf4d2b539937609d3c34e2
  • Pointer size: 131 Bytes
  • Size of remote file: 375 kB
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
vouchervision/API_validation.py CHANGED
@@ -1,4 +1,5 @@
1
  import os, io, openai, vertexai, json, tempfile
 
2
  from mistralai.client import MistralClient
3
  from mistralai.models.chat_completion import ChatMessage
4
  from langchain.schema import HumanMessage
@@ -9,7 +10,7 @@ from google.cloud import vision
9
  from google.cloud import vision_v1p3beta1 as vision_beta
10
  # from langchain_google_genai import ChatGoogleGenerativeAI
11
  from langchain_google_vertexai import VertexAI
12
-
13
 
14
  from datetime import datetime
15
  # import google.generativeai as genai
@@ -17,6 +18,8 @@ from google.oauth2 import service_account
17
  # from googleapiclient.discovery import build
18
 
19
 
 
 
20
  class APIvalidation:
21
 
22
  def __init__(self, cfg_private, dir_home, is_hf) -> None:
@@ -25,6 +28,13 @@ class APIvalidation:
25
  self.is_hf = is_hf
26
  self.formatted_date = self.get_formatted_date()
27
 
 
 
 
 
 
 
 
28
  def get_formatted_date(self):
29
  # Get the current date
30
  current_date = datetime.now()
@@ -59,7 +69,7 @@ class APIvalidation:
59
  try:
60
  # Initialize the Azure OpenAI client
61
  model = AzureChatOpenAI(
62
- deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
63
  openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
64
  openai_api_key = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'],
65
  azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
@@ -67,7 +77,7 @@ class APIvalidation:
67
  )
68
  msg = HumanMessage(content="hello")
69
  # self.llm_object.temperature = self.config.get('temperature')
70
- response = model([msg])
71
 
72
  # Check the response content (you might need to adjust this depending on how your AzureChatOpenAI class handles responses)
73
  if response:
@@ -85,7 +95,7 @@ class APIvalidation:
85
  azure_organization = os.getenv('AZURE_ORGANIZATION')
86
  # Initialize the Azure OpenAI client
87
  model = AzureChatOpenAI(
88
- deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
89
  openai_api_version = azure_api_version,
90
  openai_api_key = azure_api_key,
91
  azure_endpoint = azure_api_base,
@@ -93,7 +103,7 @@ class APIvalidation:
93
  )
94
  msg = HumanMessage(content="hello")
95
  # self.llm_object.temperature = self.config.get('temperature')
96
- response = model([msg])
97
 
98
  # Check the response content (you might need to adjust this depending on how your AzureChatOpenAI class handles responses)
99
  if response:
@@ -223,8 +233,55 @@ class APIvalidation:
223
 
224
  return results
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
 
 
 
228
  def get_google_credentials(self):
229
  if self.is_hf:
230
  creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
@@ -251,6 +308,8 @@ class APIvalidation:
251
  k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
252
  k_project_id = os.getenv('GOOGLE_PROJECT_ID')
253
  k_location = os.getenv('GOOGLE_LOCATION')
 
 
254
 
255
  k_mistral = os.getenv('MISTRAL_API_KEY')
256
  k_here = os.getenv('HERE_API_KEY')
@@ -259,6 +318,8 @@ class APIvalidation:
259
  k_OPENAI_API_KEY = self.cfg_private['openai']['OPENAI_API_KEY']
260
  k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
261
 
 
 
262
  k_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
263
  k_location = self.cfg_private['google']['GOOGLE_LOCATION']
264
  k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
@@ -284,6 +345,29 @@ class APIvalidation:
284
  present_keys.append('Google OCR Handwriting (Invalid)')
285
  else:
286
  missing_keys.append('Google OCR')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
 
289
  # OpenAI key check
@@ -297,14 +381,14 @@ class APIvalidation:
297
  missing_keys.append('OpenAI')
298
 
299
  # Azure OpenAI key check
300
- if self.has_API_key(k_openai_azure):
301
- is_valid = self.check_azure_openai_api_key()
302
- if is_valid:
303
- present_keys.append('Azure OpenAI (Valid)')
304
- else:
305
- present_keys.append('Azure OpenAI (Invalid)')
306
- else:
307
- missing_keys.append('Azure OpenAI')
308
 
309
  # Google PALM2/Gemini key check
310
  if self.has_API_key(k_google_application_credentials) and self.has_API_key(k_project_id) and self.has_API_key(k_location): ##################
 
1
  import os, io, openai, vertexai, json, tempfile
2
+ import webbrowser
3
  from mistralai.client import MistralClient
4
  from mistralai.models.chat_completion import ChatMessage
5
  from langchain.schema import HumanMessage
 
10
  from google.cloud import vision_v1p3beta1 as vision_beta
11
  # from langchain_google_genai import ChatGoogleGenerativeAI
12
  from langchain_google_vertexai import VertexAI
13
+ from huggingface_hub import HfApi, HfFolder
14
 
15
  from datetime import datetime
16
  # import google.generativeai as genai
 
18
  # from googleapiclient.discovery import build
19
 
20
 
21
+
22
+
23
  class APIvalidation:
24
 
25
  def __init__(self, cfg_private, dir_home, is_hf) -> None:
 
28
  self.is_hf = is_hf
29
  self.formatted_date = self.get_formatted_date()
30
 
31
+ self.HF_MODEL_LIST = ['microsoft/Florence-2-large','microsoft/Florence-2-base',
32
+ 'microsoft/trocr-base-handwritten','microsoft/trocr-large-handwritten',
33
+ 'google/gemma-2-9b','google/gemma-2-9b-it','google/gemma-2-27b','google/gemma-2-27b-it',
34
+ 'mistralai/Mistral-7B-Instruct-v0.3','mistralai/Mixtral-8x22B-v0.1','mistralai/Mixtral-8x22B-Instruct-v0.1',
35
+ 'unsloth/mistral-7b-instruct-v0.3-bnb-4bit'
36
+ ]
37
+
38
  def get_formatted_date(self):
39
  # Get the current date
40
  current_date = datetime.now()
 
69
  try:
70
  # Initialize the Azure OpenAI client
71
  model = AzureChatOpenAI(
72
+ deployment_name = 'gpt-4',#'gpt-35-turbo',
73
  openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
74
  openai_api_key = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'],
75
  azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
 
77
  )
78
  msg = HumanMessage(content="hello")
79
  # self.llm_object.temperature = self.config.get('temperature')
80
+ response = model.invoke([msg])
81
 
82
  # Check the response content (you might need to adjust this depending on how your AzureChatOpenAI class handles responses)
83
  if response:
 
95
  azure_organization = os.getenv('AZURE_ORGANIZATION')
96
  # Initialize the Azure OpenAI client
97
  model = AzureChatOpenAI(
98
+ deployment_name = 'gpt-4',#'gpt-35-turbo',
99
  openai_api_version = azure_api_version,
100
  openai_api_key = azure_api_key,
101
  azure_endpoint = azure_api_base,
 
103
  )
104
  msg = HumanMessage(content="hello")
105
  # self.llm_object.temperature = self.config.get('temperature')
106
+ response = model.invoke([msg])
107
 
108
  # Check the response content (you might need to adjust this depending on how your AzureChatOpenAI class handles responses)
109
  if response:
 
233
 
234
  return results
235
 
236
+ def test_hf_token(self, k_huggingface):
237
+ if not k_huggingface:
238
+ print("Hugging Face API token not found in environment variables.")
239
+ return False
240
+
241
+ # Create an instance of the API
242
+ api = HfApi()
243
+
244
+ try:
245
+ # Try to get details of a known public model
246
+ model_info = api.model_info("bert-base-uncased", use_auth_token=k_huggingface)
247
+ if model_info:
248
+ print("Token is valid. Accessed model details successfully.")
249
+ return True
250
+ else:
251
+ print("Token is valid but failed to access model details.")
252
+ return True
253
+ except Exception as e:
254
+ print(f"Failed to validate token: {e}")
255
+ return False
256
+
257
+ def check_gated_model_access(self, model_id, k_huggingface):
258
+ api = HfApi()
259
+ attempts = 0
260
+ max_attempts = 2
261
+
262
+ while attempts < max_attempts:
263
+ try:
264
+ model_info = api.model_info(model_id, use_auth_token=k_huggingface)
265
+ print(f"Access to model '{model_id}' is granted.")
266
+ return "valid"
267
+ except Exception as e:
268
+ error_message = str(e)
269
+ if 'awaiting a review' in error_message:
270
+ print(f"Access to model '{model_id}' is awaiting review. (Under Review)")
271
+ return "under_review"
272
+ print(f"Access to model '{model_id}' is denied. Please accept the terms and conditions.")
273
+ print(f"Error: {e}")
274
+ webbrowser.open(f"https://huggingface.co/{model_id}")
275
+ input("Press Enter after you have accepted the terms and conditions...")
276
+
277
+ attempts += 1
278
+
279
+ print(f"Failed to access model '{model_id}' after {max_attempts} attempts.")
280
+ return "invalid"
281
 
282
 
283
+
284
+
285
  def get_google_credentials(self):
286
  if self.is_hf:
287
  creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
 
308
  k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
309
  k_project_id = os.getenv('GOOGLE_PROJECT_ID')
310
  k_location = os.getenv('GOOGLE_LOCATION')
311
+
312
+ k_huggingface = None
313
 
314
  k_mistral = os.getenv('MISTRAL_API_KEY')
315
  k_here = os.getenv('HERE_API_KEY')
 
318
  k_OPENAI_API_KEY = self.cfg_private['openai']['OPENAI_API_KEY']
319
  k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
320
 
321
+ k_huggingface = self.cfg_private['huggingface']['hf_token']
322
+
323
  k_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
324
  k_location = self.cfg_private['google']['GOOGLE_LOCATION']
325
  k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
 
345
  present_keys.append('Google OCR Handwriting (Invalid)')
346
  else:
347
  missing_keys.append('Google OCR')
348
+
349
+ # present_keys.append('[MODEL] TEST (Under Review)')
350
+
351
+ # HF key check
352
+ if self.has_API_key(k_huggingface):
353
+ is_valid = self.test_hf_token(k_huggingface)
354
+ if is_valid:
355
+ present_keys.append('Hugging Face Local LLMs (Valid)')
356
+ else:
357
+ present_keys.append('Hugging Face Local LLMs (Invalid)')
358
+ else:
359
+ missing_keys.append('Hugging Face Local LLMs')
360
+
361
+ # List of gated models to check access for
362
+ for model_id in self.HF_MODEL_LIST:
363
+ access_status = self.check_gated_model_access(model_id, k_huggingface)
364
+ if access_status == "valid":
365
+ present_keys.append(f'[MODEL] {model_id} (Valid)')
366
+ elif access_status == "under_review":
367
+ present_keys.append(f'[MODEL] {model_id} (Under Review)')
368
+ else:
369
+ present_keys.append(f'[MODEL] {model_id} (Invalid)')
370
+
371
 
372
 
373
  # OpenAI key check
 
381
  missing_keys.append('OpenAI')
382
 
383
  # Azure OpenAI key check
384
+ # if self.has_API_key(k_openai_azure):
385
+ # is_valid = self.check_azure_openai_api_key()
386
+ # if is_valid:
387
+ # present_keys.append('Azure OpenAI (Valid)')
388
+ # else:
389
+ # present_keys.append('Azure OpenAI (Invalid)')
390
+ # else:
391
+ # missing_keys.append('Azure OpenAI')
392
 
393
  # Google PALM2/Gemini key check
394
  if self.has_API_key(k_google_application_credentials) and self.has_API_key(k_project_id) and self.has_API_key(k_location): ##################
vouchervision/LLM_crewAI.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time, os, json
2
+ import torch
3
+ from crewai import Agent, Task, Crew, Process
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain_openai import ChatOpenAI, OpenAI
6
+ from langchain.schema import HumanMessage
7
+ from langchain_core.output_parsers import JsonOutputParser
8
+ from langchain.output_parsers import RetryWithErrorOutputParser
9
+
10
+ class VoucherVisionWorkflow:
11
+ MODEL = 'gpt-4o'
12
+ SHARED_INSTRUCTIONS = """
13
+ instructions:
14
+ 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
15
+ 2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
16
+ 3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
17
+ 4. Duplicate dictionary fields are not allowed.
18
+ 5. Ensure all JSON keys are in camel case.
19
+ 6. Ensure new JSON field values follow sentence case capitalization.
20
+ 7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
21
+ 8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
22
+ 9. Only return a JSON dictionary represented as a string. You should not explain your answer.
23
+
24
+ JSON structure:
25
+ {"catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "", "speciesNameAuthorship": "", "collectedBy": "", "collectorNumber": "", "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "", "collectionDateEnd": "", "occurrenceRemarks": "", "habitat": "", "cultivated": "", "country": "", "stateProvince": "", "county": "", "locality": "", "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "", "minimumElevationInMeters": "", "maximumElevationInMeters": "", "elevationUnits": ""}
26
+ """
27
+
28
+ EXPECTED_OUTPUT_STRUCTURE = """{
29
+ "JSON_OUTPUT": {
30
+ "catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "",
31
+ "speciesNameAuthorship": "", "collectedBy": "", "collectorNumber": "",
32
+ "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "",
33
+ "collectionDateEnd": "", "occurrenceRemarks": "", "habitat": "", "cultivated": "",
34
+ "country": "", "stateProvince": "", "county": "", "locality": "",
35
+ "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "",
36
+ "minimumElevationInMeters": "", "maximumElevationInMeters": "", "elevationUnits": ""
37
+ },
38
+ "explanation": ""
39
+ }"""
40
+
41
+ def __init__(self, api_key, librarian_knowledge_path):
42
+ self.api_key = api_key
43
+ os.environ['OPENAI_API_KEY'] = self.api_key
44
+
45
+ self.librarian_knowledge = self.load_librarian_knowledge(librarian_knowledge_path)
46
+ self.worker_agent = self.create_worker_agent()
47
+ self.supervisor_agent = self.create_supervisor_agent()
48
+
49
+ def load_librarian_knowledge(self, path):
50
+ with open(path) as f:
51
+ return json.load(f)
52
+
53
+ def query_librarian(self, guideline_field):
54
+ print(f"query_librarian: {guideline_field}")
55
+ return self.librarian_knowledge.get(guideline_field, "Guideline not found.")
56
+
57
+ def create_worker_agent(self):
58
+ return Agent(
59
+ role="Transcriber and JSON Formatter",
60
+ goal="Transcribe product labels accurately and format them into a structured JSON dictionary. Only return a JSON dictionary.",
61
+ backstory="You're an AI trained to transcribe product labels and format them into JSON.",
62
+ verbose=True,
63
+ allow_delegation=False,
64
+ llm=ChatOpenAI(model=self.MODEL, openai_api_key=self.api_key),
65
+ prompt_instructions=self.SHARED_INSTRUCTIONS
66
+ )
67
+
68
+ def create_supervisor_agent(self):
69
+ class SupervisorAgent(Agent):
70
+ def correct_with_librarian(self, workflow, transcription, json_dict, guideline_field):
71
+ guideline = workflow.query_librarian(guideline_field)
72
+ corrected_transcription = self.correct(transcription, guideline)
73
+ corrected_json = self.correct_json(json_dict, guideline)
74
+ explanation = f"Corrected {json_dict} based on guideline {guideline_field}: {guideline}"
75
+ return corrected_transcription, {"JSON_OUTPUT": corrected_json, "explanation": explanation}
76
+
77
+ return SupervisorAgent(
78
+ role="Corrector",
79
+ goal="Ensure accurate transcriptions and JSON formatting according to specific guidelines. Compare the OCR text to the JSON dictionary and make any required corrections. Given your knowledge, make sure that the values in the JSON object make sense given the cumulative context of the OCR text. If you correct the provided JSON, then state the corrections. Otherwise say that the original worker was correct.",
80
+ backstory="You're an AI trained to correct transcriptions and JSON formatting, consulting the librarian for guidance.",
81
+ verbose=True,
82
+ allow_delegation=False,
83
+ llm=ChatOpenAI(model=self.MODEL, openai_api_key=self.api_key),
84
+ prompt_instructions=self.SHARED_INSTRUCTIONS
85
+ )
86
+
87
+ def extract_json_from_string(self, input_string):
88
+ json_pattern = re.compile(r'\{(?:[^{}]|(?R))*\}')
89
+ match = json_pattern.search(input_string)
90
+ if match:
91
+ return match.group(0)
92
+ return None
93
+
94
+ def extract_json_via_api(self, text):
95
+ self.api_key = self.api_key
96
+ extraction_prompt = f"I only need the JSON inside this text. Please return only the JSON object.\n\n{text}"
97
+ response = openai.ChatCompletion.create(
98
+ model=self.MODEL,
99
+ messages=[
100
+ {"role": "system", "content": extraction_prompt}
101
+ ]
102
+ )
103
+ return self.extract_json_from_string(response['choices'][0]['message']['content'])
104
+
105
+
106
+ def run_workflow(self, ocr_text):
107
+ openai_model = ChatOpenAI(api_key=self.api_key, model=self.MODEL)
108
+
109
+ self.worker_agent.llm = openai_model
110
+ self.supervisor_agent.llm = openai_model
111
+
112
+ transcription_and_formatting_task = Task(
113
+ description=f"Transcribe product label and format into JSON. OCR text: {ocr_text}",
114
+ agent=self.worker_agent,
115
+ inputs={"ocr_text": ocr_text},
116
+ expected_output=self.EXPECTED_OUTPUT_STRUCTURE
117
+ )
118
+
119
+ crew = Crew(
120
+ agents=[self.worker_agent],
121
+ tasks=[transcription_and_formatting_task],
122
+ verbose=True,
123
+ process=Process.sequential,
124
+ )
125
+
126
+ # Run the transcription and formatting task
127
+ transcription_and_formatting_result = transcription_and_formatting_task.execute()
128
+ print("Worker Output JSON:", transcription_and_formatting_result)
129
+
130
+ # Pass the worker's JSON output to the supervisor for correction
131
+ correction_task = Task(
132
+ description=f"Correct transcription and JSON format. OCR text: {ocr_text}",
133
+ agent=self.supervisor_agent,
134
+ inputs={"ocr_text": ocr_text, "json_dict": transcription_and_formatting_result},
135
+ expected_output=self.EXPECTED_OUTPUT_STRUCTURE,
136
+ workflow=self # Pass the workflow instance to the task
137
+ )
138
+
139
+ correction_result = correction_task.execute()
140
+
141
+ try:
142
+ corrected_json_with_explanation = json.loads(correction_result)
143
+ except json.JSONDecodeError:
144
+ # If initial parsing fails, make a call to OpenAI to extract only the JSON
145
+ corrected_json_string = self.extract_json_via_api(correction_result)
146
+ if not corrected_json_string:
147
+ raise ValueError("No JSON found in the supervisor's output.")
148
+ corrected_json_with_explanation = json.loads(corrected_json_string)
149
+
150
+ corrected_json = corrected_json_with_explanation["JSON_OUTPUT"]
151
+ explanation = corrected_json_with_explanation["explanation"]
152
+
153
+ print("Supervisor Corrected JSON:", corrected_json)
154
+ print("\nCorrection Explanation:", explanation)
155
+
156
+ return corrected_json, explanation
157
+
158
+ if __name__ == "__main__":
159
+ api_key = ""
160
+ librarian_knowledge_path = "D:/Dropbox/VoucherVision/vouchervision/librarian_knowledge.json"
161
+
162
+ ocr_text = "HERBARIUM OF MARYGROVE COLLEGE Name Carex scoparia V. condensa Fernald Locality Interlaken , Ind . Date 7/20/25 No ... ! Gerould Wilhelm & Laura Rericha \" Interlaken , \" was the site for many years of St. Joseph Novitiate , run by the Brothers of the Holy Cross . The buildings were on the west shore of Silver Lake , about 2 miles NE of Rolling Prairie , LaPorte Co. Indiana , ca. 41.688 \u00b0 N , 86.601 \u00b0 W Collector : Sister M. Vincent de Paul McGivney February 1 , 2011 THE UNIVERS Examined for the Flora of the Chicago Region OF 1817 MICH ! Ciscoparia SMVdeP University of Michigan Herbarium 1386297 copyright reserved cm Collector wortet 2010"
163
+ workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
164
+ workflow.run_workflow(ocr_text)
165
+
166
+ ocr_text = "CM 1 2 3 QUE : Mt.Jac.Cartier Parc de la Gasp\u00e9sie 13 Aug1988 CA Vogt on Solidago MCZ - ENT OXO Bombus vagans Smith ' det C.A. Vogt 1988 UIUC USDA BBDP 021159 00817079 "
167
+ workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
168
+ workflow.run_workflow(ocr_text)
169
+
170
+ ocr_text = "500 200 600 300 dots per inch ( optical ) 700 400 800 500 850 550 850 550 Golden Thread inches centimeters 500 200 600 300 dots per inch ( optical ) 11116 L * 39.12 65.43 49.87 44.26 b * 15.07 18.72 -22.29 22.85 4 -4.34 -13.80 3 13.24 18.11 2 1 5 9 7 11 ( A ) 10 -0.40 48.55 55.56 70.82 63.51 39.92 52.24 97.06 92.02 9.82 -33.43 34.26 11.81 -24.49 -0.35 59.60 -46.07 18.51 8 6 12 13 14 15 87.34 82.14 72.06 62.15 09.0- -0.75 -1.06 -1.19 -1.07 1.13 0.23 0.21 0.43 0.28 0.19 800 500 D50 Illuminant , 2 degree observer Density 0.04 0.09 0.15 0.22 Fam . Saurauiaceae J. G. Agardh Saurauia nepaulensis DC . S. Vietnam , Prov . Kontum . NW slopes of Ngoc Linh mountain system at 1200 m alt . near Ngoc Linh village . Secondary marshland with grasses and shrubs . Tree up to 5 m high . Flowers light rosy - pink . No VH 007 0.36 0.51 23.02.1995 International Botanical Expedition of the U.S.A. National Geographic Society ( grant No 5094-93 ) Participants : L. Averyanov , N.T. Ban , N. Q. Binh , A. Budantzev , L. Budantzev , N.T. Hiep , D.D. Huyen , P.K. Loc , N.X. Tam , G. Yakovlev BOTANICAL RESEARCH INSTITUTE OF TEXAS BRIT610199 Botanical Research Institute of Texas IMAGED 08 JUN 2021 FLORA OF VIETNAM "
171
+ workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
172
+ workflow.run_workflow(ocr_text)
173
+
174
+ ocr_text = "Russian - Vietnamese Tropical Centre Styrax argentifolius H.L. Li SOUTHERN VIETNAM Dak Lak prov . , Lak distr . , Bong Krang municip . Chu Yang Sin National Park 10 km S from Krong Kmar village River bank N 12 \u00b0 25 ' 24 \" E 108 \u00b0 21 ' 04 \" elev . 900 m Nuraliev M.S. No 1004 part of MW 0750340 29.05.2014 Materials of complex expedition in spring 2014 BOTANICAL RESEARCH INSTITUTE OF TEXAS ( BRIT ) Styrax benzoides Craib Det . by Peter W. Fritsch , September 2017 0 1 2 3 4 5 6 7 8 9 10 BOTANICAL RESEARCH INSTITUTE OF TEXAS BOTANICAL IMAGED RESEARCH INSTITUTE OF 10 JAN 2013 BRIT402114 copyright reserved cm BOTANICAL RESEARCH INSTITUTE OF TEXAS TM P CameraTrax.com BRIT . TEXAS "
175
+ workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
176
+ workflow.run_workflow(ocr_text)
vouchervision/LLM_local_MistralAI.py CHANGED
@@ -1,11 +1,13 @@
1
- import json, torch, transformers, gc
 
 
 
2
  from transformers import BitsAndBytesConfig
3
- from langchain.output_parsers import RetryWithErrorOutputParser
4
  from langchain.prompts import PromptTemplate
5
  from langchain_core.output_parsers import JsonOutputParser
6
  from huggingface_hub import hf_hub_download
7
- from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
8
-
9
  from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
10
  from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
11
 
@@ -14,7 +16,7 @@ Local Pipielines:
14
  https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
15
  '''
16
 
17
- class LocalMistralHandler:
18
  RETRY_DELAY = 2 # Wait 2 seconds before retrying
19
  MAX_RETRIES = 5 # Maximum number of retries
20
  STARTING_TEMP = 0.1
@@ -27,29 +29,22 @@ class LocalMistralHandler:
27
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
28
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
29
  self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
30
-
31
  self.logger = logger
32
  self.has_GPU = torch.cuda.is_available()
33
  self.monitor = SystemLoadMonitor(logger)
34
 
35
  self.model_name = model_name
36
  self.model_id = f"mistralai/{self.model_name}"
37
- name_parts = self.model_name.split('-')
38
-
39
- self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model",filename="config.json")
40
-
41
 
42
  self.JSON_dict_structure = JSON_dict_structure
43
  self.starting_temp = float(self.STARTING_TEMP)
44
  self.temp_increment = float(0.2)
45
- self.adjust_temp = self.starting_temp
46
-
47
- system_prompt = "You are a helpful AI assistant who answers queries a JSON dictionary as specified by the user."
48
- template = """
49
- <s>[INST]{}[/INST]</s>
50
 
51
- [INST]{}[/INST]
52
- """.format(system_prompt, "{query}")
53
 
54
  # Create a prompt from the template so we can use it with Langchain
55
  self.prompt = PromptTemplate(template=template, input_variables=["query"])
@@ -59,45 +54,22 @@ class LocalMistralHandler:
59
 
60
  self._set_config()
61
 
62
-
63
- # def _clear_VRAM(self):
64
- # # Clear CUDA cache if it's being used
65
- # if self.has_GPU:
66
- # self.local_model = None
67
- # self.local_model_pipeline = None
68
- # del self.local_model
69
- # del self.local_model_pipeline
70
- # gc.collect() # Explicitly invoke garbage collector
71
- # torch.cuda.empty_cache()
72
- # else:
73
- # self.local_model_pipeline = None
74
- # self.local_model = None
75
- # del self.local_model_pipeline
76
- # del self.local_model
77
- # gc.collect() # Explicitly invoke garbage collector
78
-
79
-
80
  def _set_config(self):
81
- # self._clear_VRAM()
82
- self.config = {'max_new_tokens': 1024,
83
- 'temperature': self.starting_temp,
84
- 'seed': 2023,
85
- 'top_p': 1,
86
- 'top_k': 40,
87
- 'do_sample': True,
88
- 'n_ctx':4096,
89
-
90
- # Activate 4-bit precision base model loading
91
- 'use_4bit': True,
92
- # Compute dtype for 4-bit base models
93
- 'bnb_4bit_compute_dtype': "float16",
94
- # Quantization type (fp4 or nf4)
95
- 'bnb_4bit_quant_type': "nf4",
96
- # Activate nested quantization for 4-bit base models (double quantization)
97
- 'use_nested_quant': False,
98
- }
99
-
100
- compute_dtype = getattr(torch,self.config.get('bnb_4bit_compute_dtype') )
101
 
102
  self.bnb_config = BitsAndBytesConfig(
103
  load_in_4bit=self.config.get('use_4bit'),
@@ -106,123 +78,102 @@ class LocalMistralHandler:
106
  bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
107
  )
108
 
109
- # Check GPU compatibility with bfloat16
110
  if compute_dtype == torch.float16 and self.config.get('use_4bit'):
111
  major, _ = torch.cuda.get_device_capability()
112
- if major >= 8:
113
- # print("=" * 80)
114
- # print("Your GPU supports bfloat16: accelerate training with bf16=True")
115
- # print("=" * 80)
116
- self.b_float_opt = torch.bfloat16
117
-
118
- else:
119
- self.b_float_opt = torch.float16
120
  self._build_model_chain_parser()
121
-
122
 
123
  def _adjust_config(self):
124
  new_temp = self.adjust_temp + self.temp_increment
125
- if self.json_report:
126
- self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
127
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
128
  self.adjust_temp += self.temp_increment
129
 
130
-
131
  def _reset_config(self):
132
- if self.json_report:
133
- self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
134
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
135
  self.adjust_temp = self.starting_temp
136
-
137
 
138
  def _build_model_chain_parser(self):
139
- self.local_model_pipeline = transformers.pipeline("text-generation",
140
- model=self.model_id,
141
- max_new_tokens=self.config.get('max_new_tokens'),
142
- top_k=self.config.get('top_k'),
143
- top_p=self.config.get('top_p'),
144
- do_sample=self.config.get('do_sample'),
145
- model_kwargs={"torch_dtype": self.b_float_opt,
146
- "load_in_4bit": True,
147
- "quantization_config": self.bnb_config})
148
  self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
 
149
  # Set up the retry parser with the runnable
150
- self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
151
  # Create an llm chain with LLM and prompt
152
- self.chain = self.prompt | self.local_model # LCEL
153
-
154
 
155
  def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
156
- _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
157
  self.json_report = json_report
158
  if self.json_report:
159
  self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
160
  self.monitor.start_monitoring_usage()
161
-
162
  nt_in = 0
163
  nt_out = 0
164
 
165
- ind = 0
166
- while ind < self.MAX_RETRIES:
167
- ind += 1
168
  try:
169
- # Dynamically set the temperature for this specific request
170
  model_kwargs = {"temperature": self.adjust_temp}
171
-
172
- # Invoke the chain to generate prompt text
173
  results = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
174
 
175
- # Use retry_parser to parse the response with retry logic
176
  output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_template)
177
 
178
  if output is None:
179
  self.logger.error(f'Failed to extract JSON from:\n{results}')
180
  self._adjust_config()
181
  del results
182
-
183
  else:
184
  nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
185
  nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)
186
 
187
  output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
188
-
189
  if output is None:
190
- self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
191
  self._adjust_config()
192
  else:
193
- self.monitor.stop_inference_timer() # Starts tool timer too
194
-
195
  if self.json_report:
196
  self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
197
- output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
 
198
 
199
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
200
 
201
- self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
202
 
203
- usage_report = self.monitor.stop_monitoring_report_usage()
204
 
205
- if self.adjust_temp != self.starting_temp:
206
  self._reset_config()
207
 
208
  if self.json_report:
209
  self.json_report.set_text(text_main=f'LLM call successful')
210
  del results
211
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
212
-
213
  except Exception as e:
214
  self.logger.error(f'{e}')
215
- self._adjust_config()
216
-
217
- self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
218
  if self.json_report:
219
- self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
220
 
221
- self.monitor.stop_inference_timer() # Starts tool timer too
222
- usage_report = self.monitor.stop_monitoring_report_usage()
223
  if self.json_report:
224
  self.json_report.set_text(text_main=f'LLM call failed')
225
 
226
  self._reset_config()
227
- return None, nt_in, nt_out, None, None, usage_report
228
-
 
1
+ import json, os
2
+ import torch
3
+ import transformers
4
+ import gc
5
  from transformers import BitsAndBytesConfig
6
+ from langchain.output_parsers.retry import RetryOutputParser
7
  from langchain.prompts import PromptTemplate
8
  from langchain_core.output_parsers import JsonOutputParser
9
  from huggingface_hub import hf_hub_download
10
+ from langchain_huggingface import HuggingFacePipeline
 
11
  from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
12
  from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
13
 
 
16
  https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
17
  '''
18
 
19
+ class LocalMistralHandler:
20
  RETRY_DELAY = 2 # Wait 2 seconds before retrying
21
  MAX_RETRIES = 5 # Maximum number of retries
22
  STARTING_TEMP = 0.1
 
29
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
30
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
31
  self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
32
+
33
  self.logger = logger
34
  self.has_GPU = torch.cuda.is_available()
35
  self.monitor = SystemLoadMonitor(logger)
36
 
37
  self.model_name = model_name
38
  self.model_id = f"mistralai/{self.model_name}"
39
+ self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json", use_auth_token=os.getenv("HUGGING_FACE_KEY"))
 
 
 
40
 
41
  self.JSON_dict_structure = JSON_dict_structure
42
  self.starting_temp = float(self.STARTING_TEMP)
43
  self.temp_increment = float(0.2)
44
+ self.adjust_temp = self.starting_temp
 
 
 
 
45
 
46
+ system_prompt = "You are a helpful AI assistant who answers queries by returning a JSON dictionary as specified by the user."
47
+ template = "<s>[INST]{}[/INST]</s>[INST]{}[/INST]".format(system_prompt, "{query}")
48
 
49
  # Create a prompt from the template so we can use it with Langchain
50
  self.prompt = PromptTemplate(template=template, input_variables=["query"])
 
54
 
55
  self._set_config()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def _set_config(self):
58
+ self.config = {
59
+ 'max_new_tokens': 1024,
60
+ 'temperature': self.starting_temp,
61
+ 'seed': 2023,
62
+ 'top_p': 1,
63
+ 'top_k': 40,
64
+ 'do_sample': True,
65
+ 'n_ctx': 4096,
66
+ 'use_4bit': True,
67
+ 'bnb_4bit_compute_dtype': "float16",
68
+ 'bnb_4bit_quant_type': "nf4",
69
+ 'use_nested_quant': False,
70
+ }
71
+
72
+ compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))
 
 
 
 
 
73
 
74
  self.bnb_config = BitsAndBytesConfig(
75
  load_in_4bit=self.config.get('use_4bit'),
 
78
  bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
79
  )
80
 
 
81
  if compute_dtype == torch.float16 and self.config.get('use_4bit'):
82
  major, _ = torch.cuda.get_device_capability()
83
+ self.b_float_opt = torch.bfloat16 if major >= 8 else torch.float16
84
+
 
 
 
 
 
 
85
  self._build_model_chain_parser()
 
86
 
87
  def _adjust_config(self):
88
  new_temp = self.adjust_temp + self.temp_increment
 
 
89
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
90
  self.adjust_temp += self.temp_increment
91
 
 
92
  def _reset_config(self):
 
 
93
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
94
  self.adjust_temp = self.starting_temp
 
95
 
96
  def _build_model_chain_parser(self):
97
+ self.local_model_pipeline = transformers.pipeline(
98
+ "text-generation",
99
+ model=self.model_id,
100
+ max_new_tokens=self.config.get('max_new_tokens'),
101
+ top_k=self.config.get('top_k'),
102
+ top_p=self.config.get('top_p'),
103
+ do_sample=self.config.get('do_sample'),
104
+ model_kwargs={"torch_dtype": self.b_float_opt, "quantization_config": self.bnb_config},
105
+ )
106
  self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
107
+
108
  # Set up the retry parser with the runnable
109
+ self.retry_parser = RetryOutputParser(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
110
  # Create an llm chain with LLM and prompt
111
+ self.chain = self.prompt | self.local_model
 
112
 
113
  def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
114
+ json_file_path_wiki, txt_file_path_ind_prompt = paths[-2:]
115
  self.json_report = json_report
116
  if self.json_report:
117
  self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
118
  self.monitor.start_monitoring_usage()
119
+
120
  nt_in = 0
121
  nt_out = 0
122
 
123
+ for ind in range(self.MAX_RETRIES):
 
 
124
  try:
 
125
  model_kwargs = {"temperature": self.adjust_temp}
 
 
126
  results = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
127
 
 
128
  output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_template)
129
 
130
  if output is None:
131
  self.logger.error(f'Failed to extract JSON from:\n{results}')
132
  self._adjust_config()
133
  del results
 
134
  else:
135
  nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
136
  nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)
137
 
138
  output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
139
+
140
  if output is None:
141
+ self.logger.error(f'[Attempt {ind + 1}] Failed to extract JSON from:\n{results}')
142
  self._adjust_config()
143
  else:
144
+ self.monitor.stop_inference_timer() # Starts tool timer too
145
+
146
  if self.json_report:
147
  self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
148
+ output_WFO, WFO_record, output_GEO, GEO_record = run_tools(
149
+ output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki
150
+ )
151
 
152
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
153
 
154
+ self.logger.info(f"Formatted JSON:\n{json.dumps(output, indent=4)}")
155
 
156
+ usage_report = self.monitor.stop_monitoring_report_usage()
157
 
158
+ if self.adjust_temp != self.starting_temp:
159
  self._reset_config()
160
 
161
  if self.json_report:
162
  self.json_report.set_text(text_main=f'LLM call successful')
163
  del results
164
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
 
165
  except Exception as e:
166
  self.logger.error(f'{e}')
167
+ self._adjust_config()
168
+
169
+ self.logger.info(f"Failed to extract valid JSON after [{self.MAX_RETRIES}] attempts")
170
  if self.json_report:
171
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{self.MAX_RETRIES}] attempts')
172
 
173
+ self.monitor.stop_inference_timer() # Starts tool timer too
174
+ usage_report = self.monitor.stop_monitoring_report_usage()
175
  if self.json_report:
176
  self.json_report.set_text(text_main=f'LLM call failed')
177
 
178
  self._reset_config()
179
+ return None, nt_in, nt_out, None, None, usage_report
 
vouchervision/LLM_local_custom_fine_tune.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, json, yaml, torch
2
+ from peft import AutoPeftModelForCausalLM
3
+ from transformers import AutoTokenizer
4
+
5
+ import json, torch, transformers, gc
6
+ from transformers import BitsAndBytesConfig
7
+ from langchain.output_parsers.retry import RetryOutputParser
8
+ from langchain.prompts import PromptTemplate
9
+ from langchain_core.output_parsers import JsonOutputParser
10
+ from huggingface_hub import hf_hub_download
11
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
12
+
13
+ from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
14
+ from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
15
+
16
+ # MODEL_NAME = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
17
+ # sltp_version = 'HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05'
18
+ # LORA = "phyloforfun/mistral-7b-instruct-v2-bnb-4bit__HLT_MICH_Angiospermae_SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05"
19
+
20
+ TEXT = "HERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. Mincral Springs edge wet subdural woods 1927 TX 11 Flowers pink UNIVERSIT HERBARIUM MICHIGAN MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 26 1965 cm "
21
+ PARENT_MODEL = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
22
+
23
+ class LocalFineTuneHandler:
24
+ RETRY_DELAY = 2 # Wait 2 seconds before retrying
25
+ MAX_RETRIES = 5 # Maximum number of retries
26
+ STARTING_TEMP = 0.001
27
+ TOKENIZER_NAME = None
28
+ VENDOR = 'mistral'
29
+ MAX_GPU_MONITORING_INTERVAL = 2 # seconds
30
+
31
+
32
+
33
+ def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation=None):
34
+ # self.model_id = f"phyloforfun/{self.model_name}"
35
+ # model_name = LORA #######################################################
36
+
37
+ # self.JSON_dict_structure = JSON_dict_structure
38
+ # self.JSON_dict_structure_str = json.dumps(self.JSON_dict_structure, sort_keys=False, indent=4)
39
+
40
+ self.JSON_dict_structure_str = """{"catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "", "scientificNameAuthorship": "", "collector": "", "recordNumber": "", "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "", "occurrenceRemarks": "", "habitat": "", "locality": "", "country": "", "stateProvince": "", "county": "", "municipality": "", "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "", "minimumElevationInMeters": "", "maximumElevationInMeters": ""}"""
41
+
42
+
43
+ self.cfg = cfg
44
+ self.print_output = True
45
+ self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
46
+ self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
47
+ self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
48
+
49
+ self.logger = logger
50
+
51
+ self.has_GPU = torch.cuda.is_available()
52
+ if self.has_GPU:
53
+ self.device = "cuda"
54
+ else:
55
+ self.device = "cpu"
56
+
57
+ self.monitor = SystemLoadMonitor(logger)
58
+
59
+ self.model_name = model_name.split("/")[1]
60
+ self.model_id = model_name
61
+
62
+ # self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model",filename="config.json")
63
+
64
+
65
+ self.starting_temp = float(self.STARTING_TEMP)
66
+ self.temp_increment = float(0.2)
67
+ self.adjust_temp = self.starting_temp
68
+
69
+ self.load_in_4bit = False
70
+
71
+ self.parser = JsonOutputParser()
72
+
73
+ self._load_model()
74
+ self._create_prompt()
75
+ self._set_config()
76
+ self._build_model_chain_parser()
77
+
78
+ def _set_config(self):
79
+ # self._clear_VRAM()
80
+ self.config = {'max_new_tokens': 1024,
81
+ 'temperature': self.starting_temp,
82
+ 'seed': 2023,
83
+ 'top_p': 1,
84
+ # 'top_k': 1,
85
+ # 'top_k': 40,
86
+ 'do_sample': False,
87
+ 'n_ctx':4096,
88
+
89
+ # Activate 4-bit precision base model loading
90
+ # 'use_4bit': True,
91
+ # # Compute dtype for 4-bit base models
92
+ # 'bnb_4bit_compute_dtype': "float16",
93
+ # # Quantization type (fp4 or nf4)
94
+ # 'bnb_4bit_quant_type': "nf4",
95
+ # # Activate nested quantization for 4-bit base models (double quantization)
96
+ # 'use_nested_quant': False,
97
+ }
98
+
99
+ def _adjust_config(self):
100
+ new_temp = self.adjust_temp + self.temp_increment
101
+ if self.json_report:
102
+ self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
103
+ self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
104
+ self.adjust_temp += self.temp_increment
105
+
106
+
107
+ def _reset_config(self):
108
+ if self.json_report:
109
+ self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
110
+ self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
111
+ self.adjust_temp = self.starting_temp
112
+
113
+
114
+ def _load_model(self):
115
+ self.model = AutoPeftModelForCausalLM.from_pretrained(
116
+ pretrained_model_name_or_path=self.model_id, # YOUR MODEL YOU USED FOR TRAINING
117
+ load_in_4bit = self.load_in_4bit,
118
+ low_cpu_mem_usage=True,
119
+
120
+ ).to(self.device)
121
+
122
+ self.tokenizer = AutoTokenizer.from_pretrained(PARENT_MODEL)
123
+ self.eos_token_id = self.tokenizer.eos_token_id
124
+
125
+
126
+ # def _build_model_chain_parser(self):
127
+ # self.local_model_pipeline = transformers.pipeline("text-generation",
128
+ # model=self.model_id,
129
+ # max_new_tokens=self.config.get('max_new_tokens'),
130
+ # # top_k=self.config.get('top_k'),
131
+ # top_p=self.config.get('top_p'),
132
+ # do_sample=self.config.get('do_sample'),
133
+ # model_kwargs={"load_in_4bit": self.load_in_4bit})
134
+ # self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
135
+ # # Set up the retry parser with the runnable
136
+ # # self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
137
+ # self.retry_parser = RetryOutputParser(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
138
+
139
+ # # Create an llm chain with LLM and prompt
140
+ # self.chain = self.prompt | self.local_model # LCEL
141
+ def _build_model_chain_parser(self):
142
+ self.local_model_pipeline = transformers.pipeline(
143
+ "text-generation",
144
+ model=self.model_id,
145
+ max_new_tokens=self.config.get('max_new_tokens'),
146
+ top_k=self.config.get('top_k', None),
147
+ top_p=self.config.get('top_p'),
148
+ do_sample=self.config.get('do_sample'),
149
+ model_kwargs={"load_in_4bit": self.load_in_4bit},
150
+ )
151
+ self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
152
+ self.retry_parser = RetryOutputParser(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
153
+
154
+
155
+
156
+ def _create_prompt(self):
157
+ self.alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
158
+
159
+ ### Instruction:
160
+ {}
161
+
162
+ ### Input:
163
+ {}
164
+
165
+ ### Response:
166
+ {}"""
167
+
168
+ self.template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
169
+
170
+ ### Instruction:
171
+ {}
172
+
173
+ ### Input:
174
+ {}
175
+
176
+ ### Response:
177
+ {}""".format("{instructions}", "{OCR_text}", "{empty}")
178
+
179
+ self.instructions_text = """Refactor the unstructured text into a valid JSON dictionary. The key names follow the Darwin Core Archive Standard. If a key lacks content, then insert an empty string. Fill in the following JSON structure as required: """
180
+ self.instructions_json = self.JSON_dict_structure_str.replace("\n ", " ").strip().replace("\n", " ")
181
+ self.instructions = ''.join([self.instructions_text, self.instructions_json])
182
+
183
+
184
+ # Create a prompt from the template so we can use it with Langchain
185
+ self.prompt = PromptTemplate(template=self.template, input_variables=["instructions", "OCR_text", "empty"])
186
+
187
+ # Set up a parser
188
+ self.parser = JsonOutputParser()
189
+
190
+
191
+ def extract_json(self, response_text):
192
+ # Assuming the response is a list with a single string entry
193
+ # response_text = response[0]
194
+
195
+ response_pattern = re.compile(r'### Response:(.*)', re.DOTALL)
196
+ response_match = response_pattern.search(response_text)
197
+ if not response_match:
198
+ raise ValueError("No '### Response:' section found in the provided text")
199
+
200
+ response_text = response_match.group(1)
201
+
202
+ # Use a regular expression to find JSON objects in the response text
203
+ json_objects = re.findall(r'\{.*?\}', response_text, re.DOTALL)
204
+
205
+ if json_objects:
206
+ # Assuming you want the first JSON object if there are multiple
207
+ json_str = json_objects[0]
208
+ # Convert the JSON string to a Python dictionary
209
+ json_dict = json.loads(json_str)
210
+ return json_str, json_dict
211
+ else:
212
+ raise ValueError("No JSON object found in the '### Response:' section")
213
+
214
+
215
+ def call_llm_local_custom_fine_tune(self, OCR_text, json_report, paths):
216
+ _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
217
+ self.json_report = json_report
218
+ if self.json_report:
219
+ self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
220
+ self.monitor.start_monitoring_usage()
221
+
222
+ nt_in = 0
223
+ nt_out = 0
224
+
225
+ self.inputs = self.tokenizer(
226
+ [
227
+ self.alpaca_prompt.format(
228
+ self.instructions, # instruction
229
+ OCR_text, # input
230
+ "", # output - leave this blank for generation!
231
+ )
232
+ ], return_tensors = "pt").to(self.device)
233
+
234
+ ind = 0
235
+ while ind < self.MAX_RETRIES:
236
+ ind += 1
237
+ try:
238
+ # Fancy
239
+ # Dynamically set the temperature for this specific request
240
+ model_kwargs = {"temperature": self.adjust_temp}
241
+
242
+ # Invoke the chain to generate prompt text
243
+ # results = self.chain.invoke({"instructions": self.instructions, "OCR_text": OCR_text, "empty": "", "model_kwargs": model_kwargs})
244
+
245
+ # Use retry_parser to parse the response with retry logic
246
+ # output = self.retry_parser.parse_with_prompt(results, prompt_value=OCR_text)
247
+ results = self.local_model.invoke(OCR_text)
248
+ output = self.retry_parser.parse_with_prompt(results, prompt_value=OCR_text)
249
+
250
+
251
+ # Should work:
252
+ # output = self.model.generate(**self.inputs, eos_token_id=self.eos_token_id, max_new_tokens=512) # Adjust max_length as needed
253
+
254
+ # Decode the generated text
255
+ # generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
256
+
257
+ # json_str, json_dict = self.extract_json(generated_text)
258
+ if self.print_output:
259
+ # print("\nJSON String:")
260
+ # print(json_str)
261
+ print("\nJSON Dictionary:")
262
+ print(output)
263
+
264
+
265
+
266
+ if output is None:
267
+ self.logger.error(f'Failed to extract JSON from:\n{results}')
268
+ self._adjust_config()
269
+ del results
270
+
271
+ else:
272
+ nt_in = count_tokens(self.instructions+OCR_text, self.VENDOR, self.TOKENIZER_NAME)
273
+ nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)
274
+
275
+ output = validate_and_align_JSON_keys_with_template(output, json.loads(self.JSON_dict_structure_str))
276
+
277
+ if output is None:
278
+ self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
279
+ self._adjust_config()
280
+ else:
281
+ self.monitor.stop_inference_timer() # Starts tool timer too
282
+
283
+ if self.json_report:
284
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
285
+ output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
286
+
287
+ save_individual_prompt(sanitize_prompt(self.instructions+OCR_text), txt_file_path_ind_prompt)
288
+
289
+ self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
290
+
291
+ usage_report = self.monitor.stop_monitoring_report_usage()
292
+
293
+ if self.adjust_temp != self.starting_temp:
294
+ self._reset_config()
295
+
296
+ if self.json_report:
297
+ self.json_report.set_text(text_main=f'LLM call successful')
298
+ del results
299
+ return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
300
+
301
+ except Exception as e:
302
+ self.logger.error(f'{e}')
303
+
304
+
305
+ self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
306
+ if self.json_report:
307
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
308
+
309
+ self.monitor.stop_inference_timer() # Starts tool timer too
310
+ usage_report = self.monitor.stop_monitoring_report_usage()
311
+ if self.json_report:
312
+ self.json_report.set_text(text_main=f'LLM call failed')
313
+
314
+ return None, nt_in, nt_out, None, None, usage_report
315
+
316
+
317
+
318
+ # # Create a prompt from the template so we can use it with Langchain
319
+ # self.prompt = PromptTemplate(template=template, input_variables=["query"])
320
+
321
+ # # Set up a parser
322
+ # self.parser = JsonOutputParser()
323
+
324
+
325
+
326
+
327
+
328
+
329
+
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
+
338
+
339
+
340
+
341
+
342
+
343
+ model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
344
+ sltp_version = 'HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05'
345
+ lora_name = "phyloforfun/mistral-7b-instruct-v2-bnb-4bit__HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05"
346
+
347
+ OCR_test = "HERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. Mincral Springs edge wet subdural woods 1927 TX 11 Flowers pink UNIVERSIT HERBARIUM MICHIGAN MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 26 1965 cm "
348
+
349
+
350
+
351
+
352
+
353
+ # model.merge_and_unload()
354
+
355
+
356
+
357
+ # Generate the output
358
+
vouchervision/OCR_Florence_2.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random, os
2
+ from PIL import Image
3
+ import copy
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as patches
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import numpy as np
8
+ import warnings
9
+ from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
10
+ from vouchervision.utils_LLM import SystemLoadMonitor
11
+
12
+ warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
13
+
14
+ class FlorenceOCR:
15
+ def __init__(self, logger, model_id='microsoft/Florence-2-large'):
16
+ self.MAX_TOKENS = 1024
17
+ self.logger = logger
18
+ self.model_id = model_id
19
+
20
+ self.monitor = SystemLoadMonitor(logger)
21
+
22
+ self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
23
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
24
+
25
+ # self.model_id_clean = "mistralai/Mistral-7B-v0.3"
26
+ self.model_id_clean = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
27
+ self.tokenizer_clean = AutoTokenizer.from_pretrained(self.model_id_clean)
28
+ self.model_clean = AutoModelForCausalLM.from_pretrained(self.model_id_clean)
29
+
30
+
31
+ def ocr_florence(self, image, task_prompt='<OCR>', text_input=None):
32
+ self.monitor.start_monitoring_usage()
33
+
34
+ # Open image if a path is provided
35
+ if isinstance(image, str):
36
+ image = Image.open(image)
37
+
38
+ if text_input is None:
39
+ prompt = task_prompt
40
+ else:
41
+ prompt = task_prompt + text_input
42
+
43
+ inputs = self.processor(text=prompt, images=image, return_tensors="pt")
44
+
45
+ # Move input_ids and pixel_values to the same device as the model
46
+ inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
47
+
48
+ generated_ids = self.model.generate(
49
+ input_ids=inputs["input_ids"],
50
+ pixel_values=inputs["pixel_values"],
51
+ max_new_tokens=self.MAX_TOKENS,
52
+ early_stopping=False,
53
+ do_sample=False,
54
+ num_beams=3,
55
+ )
56
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
57
+ parsed_answer_dirty = self.processor.post_process_generation(
58
+ generated_text,
59
+ task=task_prompt,
60
+ image_size=(image.width, image.height)
61
+ )
62
+
63
+ inputs = self.tokenizer_clean(f"Insert spaces into this text to make all the words valid. This text contains scientific names of plants, locations, habitat, coordinate words: {parsed_answer_dirty[task_prompt]}", return_tensors="pt")
64
+ inputs = {key: value.to(self.model_clean.device) for key, value in inputs.items()}
65
+
66
+ outputs = self.model_clean.generate(**inputs, max_new_tokens=self.MAX_TOKENS)
67
+ parsed_answer = self.tokenizer_clean.decode(outputs[0], skip_special_tokens=True)
68
+ print(parsed_answer_dirty)
69
+ print(parsed_answer)
70
+
71
+ self.monitor.stop_inference_timer() # Starts tool timer too
72
+ usage_report = self.monitor.stop_monitoring_report_usage()
73
+
74
+ return parsed_answer, parsed_answer_dirty[task_prompt], parsed_answer_dirty, usage_report
75
+
76
+
77
+ def main():
78
+ img_path = '/home/brlab/Downloads/gem_2024_06_26__02-26-02/Cropped_Images/By_Class/label/1.jpg'
79
+ # img = 'D:/D_Desktop/BR_1839468565_Ochnaceae_Campylospermum_reticulatum_label.jpg'
80
+
81
+ image = Image.open(img_path)
82
+
83
+ ocr = FlorenceOCR(logger = None)
84
+ results_text, results, usage_report = ocr.ocr_florence(image, task_prompt='<OCR>', text_input=None)
85
+ print(results_text)
86
+
87
+ if __name__ == '__main__':
88
+ main()
vouchervision/OCR_google_cloud_vision (DESKTOP-548UDCR's conflicted copy 2024-06-15).py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, sys, inspect, statistics, json, cv2
2
+ from statistics import mean
3
+ # from google.cloud import vision, storage
4
+ from google.cloud import vision
5
+ from google.cloud import vision_v1p3beta1 as vision_beta
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import colorsys
8
+ from tqdm import tqdm
9
+ from google.oauth2 import service_account
10
+
11
+ ### LLaVA should only be installed if the user will actually use it.
12
+ ### It requires the most recent pytorch/Python and can mess with older systems
13
+
14
+
15
+ '''
16
+ @misc{li2021trocr,
17
+ title={TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models},
18
+ author={Minghao Li and Tengchao Lv and Lei Cui and Yijuan Lu and Dinei Florencio and Cha Zhang and Zhoujun Li and Furu Wei},
19
+ year={2021},
20
+ eprint={2109.10282},
21
+ archivePrefix={arXiv},
22
+ primaryClass={cs.CL}
23
+ }
24
+ @inproceedings{baek2019character,
25
+ title={Character Region Awareness for Text Detection},
26
+ author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk},
27
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
28
+ pages={9365--9374},
29
+ year={2019}
30
+ }
31
+ '''
32
+
33
+ class OCREngine:
34
+
35
+ BBOX_COLOR = "black"
36
+
37
+ def __init__(self, logger, json_report, dir_home, is_hf, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
38
+ self.is_hf = is_hf
39
+ self.logger = logger
40
+
41
+ self.json_report = json_report
42
+
43
+ self.path = path
44
+ self.cfg = cfg
45
+ self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
46
+ self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
47
+ self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
48
+ self.dir_home = dir_home
49
+
50
+ # Initialize TrOCR components
51
+ self.trOCR_model_version = trOCR_model_version
52
+ self.trOCR_processor = trOCR_processor
53
+ self.trOCR_model = trOCR_model
54
+ self.device = device
55
+
56
+ self.hand_cleaned_text = None
57
+ self.hand_organized_text = None
58
+ self.hand_bounds = None
59
+ self.hand_bounds_word = None
60
+ self.hand_bounds_flat = None
61
+ self.hand_text_to_box_mapping = None
62
+ self.hand_height = None
63
+ self.hand_confidences = None
64
+ self.hand_characters = None
65
+
66
+ self.normal_cleaned_text = None
67
+ self.normal_organized_text = None
68
+ self.normal_bounds = None
69
+ self.normal_bounds_word = None
70
+ self.normal_text_to_box_mapping = None
71
+ self.normal_bounds_flat = None
72
+ self.normal_height = None
73
+ self.normal_confidences = None
74
+ self.normal_characters = None
75
+
76
+ self.trOCR_texts = None
77
+ self.trOCR_text_to_box_mapping = None
78
+ self.trOCR_bounds_flat = None
79
+ self.trOCR_height = None
80
+ self.trOCR_confidences = None
81
+ self.trOCR_characters = None
82
+ self.set_client()
83
+ self.init_craft()
84
+
85
+ self.multimodal_prompt = """I need you to transcribe all of the text in this image.
86
+ Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
87
+ self.init_llava()
88
+
89
+
90
+ def set_client(self):
91
+ if self.is_hf:
92
+ self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
93
+ self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
94
+ else:
95
+ self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
96
+ self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
97
+
98
+
99
+ def get_google_credentials(self):
100
+ creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
101
+ credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
102
+ return credentials
103
+
104
+ def init_craft(self):
105
+ if 'CRAFT' in self.OCR_option:
106
+ from craft_text_detector import load_craftnet_model, load_refinenet_model
107
+
108
+ try:
109
+ self.refine_net = load_refinenet_model(cuda=True)
110
+ self.use_cuda = True
111
+ except:
112
+ self.refine_net = load_refinenet_model(cuda=False)
113
+ self.use_cuda = False
114
+
115
+ if self.use_cuda:
116
+ self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=True)
117
+ else:
118
+ self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
119
+
120
+ def init_llava(self):
121
+ if 'LLaVA' in self.OCR_option:
122
+ from vouchervision.OCR_llava import OCRllava
123
+
124
+ self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
125
+ self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
126
+
127
+ if self.json_report:
128
+ self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
129
+
130
+ if self.model_quant == '4bit':
131
+ use_4bit = True
132
+ elif self.model_quant == 'full':
133
+ use_4bit = False
134
+ else:
135
+ self.logger.info(f"Provided model quantization invlid. Using 4bit.")
136
+ use_4bit = True
137
+
138
+ self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
139
+
140
+ def init_gemini_vision(self):
141
+ pass
142
+
143
+ def init_gpt4_vision(self):
144
+ pass
145
+
146
+
147
+ def detect_text_craft(self):
148
+ from craft_text_detector import read_image, get_prediction
149
+
150
+ # Perform prediction using CRAFT
151
+ image = read_image(self.path)
152
+
153
+ link_threshold = 0.85
154
+ text_threshold = 0.4
155
+ low_text = 0.4
156
+
157
+ if self.use_cuda:
158
+ self.prediction_result = get_prediction(
159
+ image=image,
160
+ craft_net=self.craft_net,
161
+ refine_net=self.refine_net,
162
+ text_threshold=text_threshold,
163
+ link_threshold=link_threshold,
164
+ low_text=low_text,
165
+ cuda=True,
166
+ long_size=1280
167
+ )
168
+ else:
169
+ self.prediction_result = get_prediction(
170
+ image=image,
171
+ craft_net=self.craft_net,
172
+ refine_net=self.refine_net,
173
+ text_threshold=text_threshold,
174
+ link_threshold=link_threshold,
175
+ low_text=low_text,
176
+ cuda=False,
177
+ long_size=1280
178
+ )
179
+
180
+ # Initialize metadata structures
181
+ bounds = []
182
+ bounds_word = [] # CRAFT gives bounds for text regions, not individual words
183
+ text_to_box_mapping = []
184
+ bounds_flat = []
185
+ height_flat = []
186
+ confidences = [] # CRAFT does not provide confidences per character, so this might be uniformly set or estimated
187
+ characters = [] # Simulating as CRAFT doesn't provide character-level details
188
+ organized_text = ""
189
+
190
+ total_b = len(self.prediction_result["boxes"])
191
+ i=0
192
+ # Process each detected text region
193
+ for box in self.prediction_result["boxes"]:
194
+ i+=1
195
+ if self.json_report:
196
+ self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
197
+
198
+ vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
199
+
200
+ # Simulate a mapping for the whole detected region as a word
201
+ text_to_box_mapping.append({
202
+ "vertices": vertices,
203
+ "text": "detected_text" # Placeholder, as CRAFT does not provide the text content directly
204
+ })
205
+
206
+ # Assuming each box is a word for the sake of this example
207
+ bounds_word.append({"vertices": vertices})
208
+
209
+ # For simplicity, we're not dividing text regions into characters as CRAFT doesn't provide this
210
+ # Instead, we create a single large 'character' per detected region
211
+ bounds.append({"vertices": vertices})
212
+
213
+ # Simulate flat bounds and height for each detected region
214
+ x_positions = [vertex["x"] for vertex in vertices]
215
+ y_positions = [vertex["y"] for vertex in vertices]
216
+ min_x, max_x = min(x_positions), max(x_positions)
217
+ min_y, max_y = min(y_positions), max(y_positions)
218
+ avg_height = max_y - min_y
219
+ height_flat.append(avg_height)
220
+
221
+ # Assuming uniform confidence for all detected regions
222
+ confidences.append(1.0) # Placeholder confidence
223
+
224
+ # Adding dummy character for each box
225
+ characters.append("X") # Placeholder character
226
+
227
+ # Organize text as a single string (assuming each box is a word)
228
+ # organized_text += "detected_text " # Placeholder text
229
+
230
+ # Update class attributes with processed data
231
+ self.normal_bounds = bounds
232
+ self.normal_bounds_word = bounds_word
233
+ self.normal_text_to_box_mapping = text_to_box_mapping
234
+ self.normal_bounds_flat = bounds_flat # This would be similar to bounds if not processing characters individually
235
+ self.normal_height = height_flat
236
+ self.normal_confidences = confidences
237
+ self.normal_characters = characters
238
+ self.normal_organized_text = organized_text.strip()
239
+
240
+
241
+ def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
242
+ CONFIDENCES = 0.80
243
+ MAX_NEW_TOKENS = 50
244
+
245
+ self.OCR_JSON_to_file = {}
246
+
247
+ ocr_parts = ''
248
+ if not do_use_trOCR:
249
+ if 'normal' in self.OCR_option:
250
+ self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
251
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
252
+ # ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
253
+ ocr_parts = self.normal_organized_text
254
+
255
+ if 'hand' in self.OCR_option:
256
+ self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
257
+ # logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
258
+ # ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
259
+ ocr_parts = self.hand_organized_text
260
+
261
+ # if self.OCR_option in ['both',]:
262
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}")
263
+ # return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}"
264
+ return ocr_parts
265
+ else:
266
+ logger.info(f'Supplementing with trOCR')
267
+
268
+ self.trOCR_texts = []
269
+ original_image = Image.open(self.path).convert("RGB")
270
+
271
+ if 'normal' in self.OCR_option or 'CRAFT' in self.OCR_option:
272
+ available_bounds = self.normal_bounds_word
273
+ elif 'hand' in self.OCR_option:
274
+ available_bounds = self.hand_bounds_word
275
+ # elif self.OCR_option in ['both',]:
276
+ # available_bounds = self.hand_bounds_word
277
+ else:
278
+ raise
279
+
280
+ text_to_box_mapping = []
281
+ characters = []
282
+ height = []
283
+ confidences = []
284
+ total_b = len(available_bounds)
285
+ i=0
286
+ for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
287
+ i+=1
288
+ if self.json_report:
289
+ self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
290
+
291
+ vertices = bound["vertices"]
292
+
293
+ left = min([v["x"] for v in vertices])
294
+ top = min([v["y"] for v in vertices])
295
+ right = max([v["x"] for v in vertices])
296
+ bottom = max([v["y"] for v in vertices])
297
+
298
+ # Crop image based on Google's bounding box
299
+ cropped_image = original_image.crop((left, top, right, bottom))
300
+ pixel_values = self.trOCR_processor(cropped_image, return_tensors="pt").pixel_values
301
+
302
+ # Move pixel values to the appropriate device
303
+ pixel_values = pixel_values.to(self.device)
304
+
305
+ generated_ids = self.trOCR_model.generate(pixel_values, max_new_tokens=MAX_NEW_TOKENS)
306
+ extracted_text = self.trOCR_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
307
+ self.trOCR_texts.append(extracted_text)
308
+
309
+ # For plotting
310
+ word_length = max(vertex.get('x') for vertex in vertices) - min(vertex.get('x') for vertex in vertices)
311
+ num_symbols = len(extracted_text)
312
+
313
+ Yw = max(vertex.get('y') for vertex in vertices)
314
+ Yo = Yw - min(vertex.get('y') for vertex in vertices)
315
+ X = word_length / num_symbols if num_symbols > 0 else 0
316
+ H = int(X+(Yo*0.1))
317
+ height.append(H)
318
+
319
+ map_dict = {
320
+ "vertices": vertices,
321
+ "text": extracted_text # Use the text extracted by trOCR
322
+ }
323
+ text_to_box_mapping.append(map_dict)
324
+
325
+ characters.append(extracted_text)
326
+ confidences.append(CONFIDENCES)
327
+
328
+ median_height = statistics.median(height) if height else 0
329
+ median_heights = [median_height * 1.5] * len(characters)
330
+
331
+ self.trOCR_texts = ' '.join(self.trOCR_texts)
332
+
333
+ self.trOCR_text_to_box_mapping = text_to_box_mapping
334
+ self.trOCR_bounds_flat = available_bounds
335
+ self.trOCR_height = median_heights
336
+ self.trOCR_confidences = confidences
337
+ self.trOCR_characters = characters
338
+
339
+ if 'normal' in self.OCR_option:
340
+ self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
341
+ self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
342
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
343
+ # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
344
+ ocr_parts = self.trOCR_texts
345
+ if 'hand' in self.OCR_option:
346
+ self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
347
+ self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
348
+ # logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
349
+ # ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
350
+ ocr_parts = self.trOCR_texts
351
+ # if self.OCR_option in ['both',]:
352
+ # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
353
+ # self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
354
+ # self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
355
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
356
+ # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
357
+ if 'CRAFT' in self.OCR_option:
358
+ # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
359
+ self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
360
+ # logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
361
+ # ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
362
+ ocr_parts = self.trOCR_texts
363
+ return ocr_parts
364
+
365
+ @staticmethod
366
+ def confidence_to_color(confidence):
367
+ hue = (confidence - 0.5) * 120 / 0.5
368
+ r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 1)
369
+ return (int(r*255), int(g*255), int(b*255))
370
+
371
+
372
+ def render_text_on_black_image(self, option):
373
+ bounds_flat = getattr(self, f'{option}_bounds_flat', [])
374
+ heights = getattr(self, f'{option}_height', [])
375
+ confidences = getattr(self, f'{option}_confidences', [])
376
+ characters = getattr(self, f'{option}_characters', [])
377
+
378
+ original_image = Image.open(self.path)
379
+ width, height = original_image.size
380
+ black_image = Image.new("RGB", (width, height), "black")
381
+ draw = ImageDraw.Draw(black_image)
382
+
383
+ for bound, confidence, char_height, character in zip(bounds_flat, confidences, heights, characters):
384
+ font_size = int(char_height)
385
+ try:
386
+ font = ImageFont.truetype("arial.ttf", font_size)
387
+ except:
388
+ font = ImageFont.load_default().font_variant(size=font_size)
389
+ if option == 'trOCR':
390
+ color = (0, 170, 255)
391
+ else:
392
+ color = OCREngine.confidence_to_color(confidence)
393
+ position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
394
+ draw.text(position, character, fill=color, font=font)
395
+
396
+ return black_image
397
+
398
+
399
+ def merge_images(self, image1, image2):
400
+ width1, height1 = image1.size
401
+ width2, height2 = image2.size
402
+ merged_image = Image.new("RGB", (width1 + width2, max([height1, height2])))
403
+ merged_image.paste(image1, (0, 0))
404
+ merged_image.paste(image2, (width1, 0))
405
+ return merged_image
406
+
407
+
408
+ def draw_boxes(self, option):
409
+ bounds = getattr(self, f'{option}_bounds', [])
410
+ bounds_word = getattr(self, f'{option}_bounds_word', [])
411
+ confidences = getattr(self, f'{option}_confidences', [])
412
+
413
+ draw = ImageDraw.Draw(self.image)
414
+ width, height = self.image.size
415
+ if min([width, height]) > 4000:
416
+ line_width_thick = int((width + height) / 2 * 0.0025) # Adjust line width for character level
417
+ line_width_thin = 1
418
+ else:
419
+ line_width_thick = int((width + height) / 2 * 0.005) # Adjust line width for character level
420
+ line_width_thin = 1 #int((width + height) / 2 * 0.001)
421
+
422
+ for bound in bounds_word:
423
+ draw.polygon(
424
+ [
425
+ bound["vertices"][0]["x"], bound["vertices"][0]["y"],
426
+ bound["vertices"][1]["x"], bound["vertices"][1]["y"],
427
+ bound["vertices"][2]["x"], bound["vertices"][2]["y"],
428
+ bound["vertices"][3]["x"], bound["vertices"][3]["y"],
429
+ ],
430
+ outline=OCREngine.BBOX_COLOR,
431
+ width=line_width_thin
432
+ )
433
+
434
+ # Draw a line segment at the bottom of each handwritten character
435
+ for bound, confidence in zip(bounds, confidences):
436
+ color = OCREngine.confidence_to_color(confidence)
437
+ # Use the bottom two vertices of the bounding box for the line
438
+ bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width_thick)
439
+ bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width_thick)
440
+ draw.line([bottom_left, bottom_right], fill=color, width=line_width_thick)
441
+
442
+ return self.image
443
+
444
+
445
+ def detect_text(self):
446
+
447
+ with io.open(self.path, 'rb') as image_file:
448
+ content = image_file.read()
449
+ image = vision.Image(content=content)
450
+ response = self.client.document_text_detection(image=image)
451
+ texts = response.text_annotations
452
+
453
+ if response.error.message:
454
+ raise Exception(
455
+ '{}\nFor more info on error messages, check: '
456
+ 'https://cloud.google.com/apis/design/errors'.format(
457
+ response.error.message))
458
+
459
+ bounds = []
460
+ bounds_word = []
461
+ text_to_box_mapping = []
462
+ bounds_flat = []
463
+ height_flat = []
464
+ confidences = []
465
+ characters = []
466
+ organized_text = ""
467
+ paragraph_count = 0
468
+
469
+ for text in texts[1:]:
470
+ vertices = [{"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices]
471
+ map_dict = {
472
+ "vertices": vertices,
473
+ "text": text.description
474
+ }
475
+ text_to_box_mapping.append(map_dict)
476
+
477
+ for page in response.full_text_annotation.pages:
478
+ for block in page.blocks:
479
+ # paragraph_count += 1
480
+ # organized_text += f'OCR_paragraph_{paragraph_count}:\n' # Add paragraph label
481
+ for paragraph in block.paragraphs:
482
+
483
+ avg_H_list = []
484
+ for word in paragraph.words:
485
+ Yw = max(vertex.y for vertex in word.bounding_box.vertices)
486
+ # Calculate the width of the word and divide by the number of symbols
487
+ word_length = max(vertex.x for vertex in word.bounding_box.vertices) - min(vertex.x for vertex in word.bounding_box.vertices)
488
+ num_symbols = len(word.symbols)
489
+ if num_symbols <= 3:
490
+ H = int(Yw - min(vertex.y for vertex in word.bounding_box.vertices))
491
+ else:
492
+ Yo = Yw - min(vertex.y for vertex in word.bounding_box.vertices)
493
+ X = word_length / num_symbols if num_symbols > 0 else 0
494
+ H = int(X+(Yo*0.1))
495
+ avg_H_list.append(H)
496
+ avg_H = int(mean(avg_H_list))
497
+
498
+ words_in_para = []
499
+ for word in paragraph.words:
500
+ # Get word-level bounding box
501
+ bound_word_dict = {
502
+ "vertices": [
503
+ {"x": vertex.x, "y": vertex.y} for vertex in word.bounding_box.vertices
504
+ ]
505
+ }
506
+ bounds_word.append(bound_word_dict)
507
+
508
+ Y = max(vertex.y for vertex in word.bounding_box.vertices)
509
+ word_x_start = min(vertex.x for vertex in word.bounding_box.vertices)
510
+ word_x_end = max(vertex.x for vertex in word.bounding_box.vertices)
511
+ num_symbols = len(word.symbols)
512
+ symbol_width = (word_x_end - word_x_start) / num_symbols if num_symbols > 0 else 0
513
+
514
+ current_x_position = word_x_start
515
+
516
+ characters_ind = []
517
+ for symbol in word.symbols:
518
+ bound_dict = {
519
+ "vertices": [
520
+ {"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
521
+ ]
522
+ }
523
+ bounds.append(bound_dict)
524
+
525
+ # Create flat bounds with adjusted x position
526
+ bounds_flat_dict = {
527
+ "vertices": [
528
+ {"x": current_x_position, "y": Y},
529
+ {"x": current_x_position + symbol_width, "y": Y}
530
+ ]
531
+ }
532
+ bounds_flat.append(bounds_flat_dict)
533
+ current_x_position += symbol_width
534
+
535
+ height_flat.append(avg_H)
536
+ confidences.append(round(symbol.confidence, 4))
537
+
538
+ characters_ind.append(symbol.text)
539
+ characters.append(symbol.text)
540
+
541
+ words_in_para.append(''.join(characters_ind))
542
+ paragraph_text = ' '.join(words_in_para) # Join words in paragraph
543
+ organized_text += paragraph_text + ' ' #+ '\n'
544
+
545
+ # median_height = statistics.median(height_flat) if height_flat else 0
546
+ # median_heights = [median_height] * len(characters)
547
+
548
+ self.normal_cleaned_text = texts[0].description if texts else ''
549
+ self.normal_organized_text = organized_text
550
+ self.normal_bounds = bounds
551
+ self.normal_bounds_word = bounds_word
552
+ self.normal_text_to_box_mapping = text_to_box_mapping
553
+ self.normal_bounds_flat = bounds_flat
554
+ # self.normal_height = median_heights #height_flat
555
+ self.normal_height = height_flat
556
+ self.normal_confidences = confidences
557
+ self.normal_characters = characters
558
+ return self.normal_cleaned_text
559
+
560
+
561
+ def detect_handwritten_ocr(self):
562
+
563
+ with open(self.path, "rb") as image_file:
564
+ content = image_file.read()
565
+
566
+ image = vision_beta.Image(content=content)
567
+ image_context = vision_beta.ImageContext(language_hints=["en-t-i0-handwrit"])
568
+ response = self.client_beta.document_text_detection(image=image, image_context=image_context)
569
+ texts = response.text_annotations
570
+
571
+ if response.error.message:
572
+ raise Exception(
573
+ "{}\nFor more info on error messages, check: "
574
+ "https://cloud.google.com/apis/design/errors".format(response.error.message)
575
+ )
576
+
577
+ bounds = []
578
+ bounds_word = []
579
+ bounds_flat = []
580
+ height_flat = []
581
+ confidences = []
582
+ characters = []
583
+ organized_text = ""
584
+ paragraph_count = 0
585
+ text_to_box_mapping = []
586
+
587
+ for text in texts[1:]:
588
+ vertices = [{"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices]
589
+ map_dict = {
590
+ "vertices": vertices,
591
+ "text": text.description
592
+ }
593
+ text_to_box_mapping.append(map_dict)
594
+
595
+ for page in response.full_text_annotation.pages:
596
+ for block in page.blocks:
597
+ # paragraph_count += 1
598
+ # organized_text += f'\nOCR_paragraph_{paragraph_count}:\n' # Add paragraph label
599
+ for paragraph in block.paragraphs:
600
+
601
+ avg_H_list = []
602
+ for word in paragraph.words:
603
+ Yw = max(vertex.y for vertex in word.bounding_box.vertices)
604
+ # Calculate the width of the word and divide by the number of symbols
605
+ word_length = max(vertex.x for vertex in word.bounding_box.vertices) - min(vertex.x for vertex in word.bounding_box.vertices)
606
+ num_symbols = len(word.symbols)
607
+ if num_symbols <= 3:
608
+ H = int(Yw - min(vertex.y for vertex in word.bounding_box.vertices))
609
+ else:
610
+ Yo = Yw - min(vertex.y for vertex in word.bounding_box.vertices)
611
+ X = word_length / num_symbols if num_symbols > 0 else 0
612
+ H = int(X+(Yo*0.1))
613
+ avg_H_list.append(H)
614
+ avg_H = int(mean(avg_H_list))
615
+
616
+ words_in_para = []
617
+ for word in paragraph.words:
618
+ # Get word-level bounding box
619
+ bound_word_dict = {
620
+ "vertices": [
621
+ {"x": vertex.x, "y": vertex.y} for vertex in word.bounding_box.vertices
622
+ ]
623
+ }
624
+ bounds_word.append(bound_word_dict)
625
+
626
+ Y = max(vertex.y for vertex in word.bounding_box.vertices)
627
+ word_x_start = min(vertex.x for vertex in word.bounding_box.vertices)
628
+ word_x_end = max(vertex.x for vertex in word.bounding_box.vertices)
629
+ num_symbols = len(word.symbols)
630
+ symbol_width = (word_x_end - word_x_start) / num_symbols if num_symbols > 0 else 0
631
+
632
+ current_x_position = word_x_start
633
+
634
+ characters_ind = []
635
+ for symbol in word.symbols:
636
+ bound_dict = {
637
+ "vertices": [
638
+ {"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
639
+ ]
640
+ }
641
+ bounds.append(bound_dict)
642
+
643
+ # Create flat bounds with adjusted x position
644
+ bounds_flat_dict = {
645
+ "vertices": [
646
+ {"x": current_x_position, "y": Y},
647
+ {"x": current_x_position + symbol_width, "y": Y}
648
+ ]
649
+ }
650
+ bounds_flat.append(bounds_flat_dict)
651
+ current_x_position += symbol_width
652
+
653
+ height_flat.append(avg_H)
654
+ confidences.append(round(symbol.confidence, 4))
655
+
656
+ characters_ind.append(symbol.text)
657
+ characters.append(symbol.text)
658
+
659
+ words_in_para.append(''.join(characters_ind))
660
+ paragraph_text = ' '.join(words_in_para) # Join words in paragraph
661
+ organized_text += paragraph_text + ' ' #+ '\n'
662
+
663
+ # median_height = statistics.median(height_flat) if height_flat else 0
664
+ # median_heights = [median_height] * len(characters)
665
+
666
+ self.hand_cleaned_text = response.text_annotations[0].description if response.text_annotations else ''
667
+ self.hand_organized_text = organized_text
668
+ self.hand_bounds = bounds
669
+ self.hand_bounds_word = bounds_word
670
+ self.hand_bounds_flat = bounds_flat
671
+ self.hand_text_to_box_mapping = text_to_box_mapping
672
+ # self.hand_height = median_heights #height_flat
673
+ self.hand_height = height_flat
674
+ self.hand_confidences = confidences
675
+ self.hand_characters = characters
676
+ return self.hand_cleaned_text
677
+
678
+
679
+ def process_image(self, do_create_OCR_helper_image, logger):
680
+ # Can stack options, so solitary if statements
681
+ self.OCR = 'OCR:\n'
682
+ if 'CRAFT' in self.OCR_option:
683
+ self.do_use_trOCR = True
684
+ self.detect_text_craft()
685
+ ### Optionally add trOCR to the self.OCR for additional context
686
+ if self.double_OCR:
687
+ part_OCR = "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
688
+ self.OCR = self.OCR + part_OCR + part_OCR
689
+ else:
690
+ self.OCR = self.OCR + "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
691
+ # logger.info(f"CRAFT trOCR:\n{self.OCR}")
692
+
693
+ if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
694
+ if self.json_report:
695
+ self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
696
+
697
+ image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
698
+ self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
699
+
700
+ try:
701
+ self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
702
+ except:
703
+ self.OCR_JSON_to_file = {}
704
+ self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
705
+
706
+ if self.double_OCR:
707
+ self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
708
+ else:
709
+ self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
710
+ # logger.info(f"LLaVA OCR:\n{self.OCR}")
711
+
712
+ if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
713
+ if 'normal' in self.OCR_option:
714
+ if self.double_OCR:
715
+ part_OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
716
+ self.OCR = self.OCR + part_OCR + part_OCR
717
+ else:
718
+ self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
719
+ if 'hand' in self.OCR_option:
720
+ if self.double_OCR:
721
+ part_OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
722
+ self.OCR = self.OCR + part_OCR + part_OCR
723
+ else:
724
+ self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
725
+ # if self.OCR_option not in ['normal', 'hand', 'both']:
726
+ # self.OCR_option = 'both'
727
+ # self.detect_text()
728
+ # self.detect_handwritten_ocr()
729
+
730
+ ### Optionally add trOCR to the self.OCR for additional context
731
+ if self.do_use_trOCR:
732
+ if self.double_OCR:
733
+ part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
734
+ self.OCR = self.OCR + part_OCR + part_OCR
735
+ else:
736
+ self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
737
+ # logger.info(f"OCR:\n{self.OCR}")
738
+ else:
739
+ # populate self.OCR_JSON_to_file = {}
740
+ _ = self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
741
+
742
+
743
+ if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
744
+ self.image = Image.open(self.path)
745
+
746
+ if 'normal' in self.OCR_option:
747
+ image_with_boxes_normal = self.draw_boxes('normal')
748
+ text_image_normal = self.render_text_on_black_image('normal')
749
+ self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_normal)
750
+
751
+ if 'hand' in self.OCR_option:
752
+ image_with_boxes_hand = self.draw_boxes('hand')
753
+ text_image_hand = self.render_text_on_black_image('hand')
754
+ self.merged_image_hand = self.merge_images(image_with_boxes_hand, text_image_hand)
755
+
756
+ if self.do_use_trOCR:
757
+ text_image_trOCR = self.render_text_on_black_image('trOCR')
758
+
759
+ if 'CRAFT' in self.OCR_option:
760
+ image_with_boxes_normal = self.draw_boxes('normal')
761
+ self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)
762
+
763
+ ### Merge final overlay image
764
+ ### [original, normal bboxes, normal text]
765
+ if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
766
+ self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
767
+ ### [original, hand bboxes, hand text]
768
+ elif 'hand' in self.OCR_option:
769
+ self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
770
+ ### [original, normal bboxes, normal text, hand bboxes, hand text]
771
+ else:
772
+ self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
773
+
774
+ if self.do_use_trOCR:
775
+ if 'CRAFT' in self.OCR_option:
776
+ heat_map_text = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["text_score_heatmap"], cv2.COLOR_BGR2RGB))
777
+ heat_map_link = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["link_score_heatmap"], cv2.COLOR_BGR2RGB))
778
+ self.overlay_image = self.merge_images(self.overlay_image, heat_map_text)
779
+ self.overlay_image = self.merge_images(self.overlay_image, heat_map_link)
780
+
781
+ else:
782
+ self.overlay_image = self.merge_images(self.overlay_image, text_image_trOCR)
783
+
784
+ else:
785
+ self.merged_image_normal = None
786
+ self.merged_image_hand = None
787
+ self.overlay_image = Image.open(self.path)
788
+
789
+ try:
790
+ from craft_text_detector import empty_cuda_cache
791
+ empty_cuda_cache()
792
+ except:
793
+ pass
794
+
795
+ class SafetyCheck():
796
+ def __init__(self, is_hf) -> None:
797
+ self.is_hf = is_hf
798
+ self.set_client()
799
+
800
+ def set_client(self):
801
+ if self.is_hf:
802
+ self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
803
+ else:
804
+ self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
805
+
806
+ def get_google_credentials(self):
807
+ creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
808
+ credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
809
+ return credentials
810
+
811
+ def check_for_inappropriate_content(self, file_stream):
812
+ try:
813
+ LEVEL = 2
814
+ # content = file_stream.read()
815
+ file_stream.seek(0) # Reset file stream position to the beginning
816
+ content = file_stream.read()
817
+ image = vision.Image(content=content)
818
+ response = self.client.safe_search_detection(image=image)
819
+ safe = response.safe_search_annotation
820
+
821
+ likelihood_name = (
822
+ "UNKNOWN",
823
+ "VERY_UNLIKELY",
824
+ "UNLIKELY",
825
+ "POSSIBLE",
826
+ "LIKELY",
827
+ "VERY_LIKELY",
828
+ )
829
+ print("Safe search:")
830
+
831
+ print(f" adult*: {likelihood_name[safe.adult]}")
832
+ print(f" medical*: {likelihood_name[safe.medical]}")
833
+ print(f" spoofed: {likelihood_name[safe.spoof]}")
834
+ print(f" violence*: {likelihood_name[safe.violence]}")
835
+ print(f" racy: {likelihood_name[safe.racy]}")
836
+
837
+ # Check the levels of adult, violence, racy, etc. content.
838
+ if (safe.adult > LEVEL or
839
+ safe.medical > LEVEL or
840
+ # safe.spoof > LEVEL or
841
+ safe.violence > LEVEL #or
842
+ # safe.racy > LEVEL
843
+ ):
844
+ print("Found violation")
845
+ return True # The image violates safe search guidelines.
846
+
847
+ print("Found NO violation")
848
+ return False # The image is considered safe.
849
+ except:
850
+ return False # The image is considered safe. TEMPOROARY FIX TODO
vouchervision/OCR_google_cloud_vision.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
7
  import colorsys
8
  from tqdm import tqdm
9
  from google.oauth2 import service_account
10
-
11
  ### LLaVA should only be installed if the user will actually use it.
12
  ### It requires the most recent pytorch/Python and can mess with older systems
13
 
@@ -43,6 +43,7 @@ class OCREngine:
43
  self.path = path
44
  self.cfg = cfg
45
  self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
 
46
  self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
47
  self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
48
  self.dir_home = dir_home
@@ -53,6 +54,8 @@ class OCREngine:
53
  self.trOCR_model = trOCR_model
54
  self.device = device
55
 
 
 
56
  self.hand_cleaned_text = None
57
  self.hand_organized_text = None
58
  self.hand_bounds = None
@@ -80,6 +83,7 @@ class OCREngine:
80
  self.trOCR_confidences = None
81
  self.trOCR_characters = None
82
  self.set_client()
 
83
  self.init_craft()
84
 
85
  self.multimodal_prompt = """I need you to transcribe all of the text in this image.
@@ -117,6 +121,10 @@ class OCREngine:
117
  else:
118
  self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
119
 
 
 
 
 
120
  def init_llava(self):
121
  if 'LLaVA' in self.OCR_option:
122
  from vouchervision.OCR_llava import OCRllava
@@ -241,8 +249,6 @@ class OCREngine:
241
  def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
242
  CONFIDENCES = 0.80
243
  MAX_NEW_TOKENS = 50
244
-
245
- self.OCR_JSON_to_file = {}
246
 
247
  ocr_parts = ''
248
  if not do_use_trOCR:
@@ -677,6 +683,9 @@ class OCREngine:
677
 
678
 
679
  def process_image(self, do_create_OCR_helper_image, logger):
 
 
 
680
  # Can stack options, so solitary if statements
681
  self.OCR = 'OCR:\n'
682
  if 'CRAFT' in self.OCR_option:
@@ -697,11 +706,7 @@ class OCREngine:
697
  image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
698
  self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
699
 
700
- try:
701
- self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
702
- except:
703
- self.OCR_JSON_to_file = {}
704
- self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
705
 
706
  if self.double_OCR:
707
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
@@ -709,6 +714,20 @@ class OCREngine:
709
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
710
  # logger.info(f"LLaVA OCR:\n{self.OCR}")
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
713
  if 'normal' in self.OCR_option:
714
  if self.double_OCR:
@@ -762,14 +781,16 @@ class OCREngine:
762
 
763
  ### Merge final overlay image
764
  ### [original, normal bboxes, normal text]
765
- if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
766
- self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
767
- ### [original, hand bboxes, hand text]
768
- elif 'hand' in self.OCR_option:
769
- self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
770
- ### [original, normal bboxes, normal text, hand bboxes, hand text]
771
- else:
772
- self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
 
 
773
 
774
  if self.do_use_trOCR:
775
  if 'CRAFT' in self.OCR_option:
 
7
  import colorsys
8
  from tqdm import tqdm
9
  from google.oauth2 import service_account
10
+ from OCR_Florence_2 import FlorenceOCR
11
  ### LLaVA should only be installed if the user will actually use it.
12
  ### It requires the most recent pytorch/Python and can mess with older systems
13
 
 
43
  self.path = path
44
  self.cfg = cfg
45
  self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
46
+ self.do_use_florence = self.cfg['leafmachine']['project']['do_use_florence']
47
  self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
48
  self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
49
  self.dir_home = dir_home
 
54
  self.trOCR_model = trOCR_model
55
  self.device = device
56
 
57
+ self.OCR_JSON_to_file = {}
58
+
59
  self.hand_cleaned_text = None
60
  self.hand_organized_text = None
61
  self.hand_bounds = None
 
83
  self.trOCR_confidences = None
84
  self.trOCR_characters = None
85
  self.set_client()
86
+ self.init_florence()
87
  self.init_craft()
88
 
89
  self.multimodal_prompt = """I need you to transcribe all of the text in this image.
 
121
  else:
122
  self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
123
 
124
+ def init_florence(self):
125
+ if 'Florence-2' in self.OCR_option:
126
+ self.Florence = FlorenceOCR(logger=self.logger, model_id=self.cfg['leafmachine']['project']['florence_model_path'])
127
+
128
  def init_llava(self):
129
  if 'LLaVA' in self.OCR_option:
130
  from vouchervision.OCR_llava import OCRllava
 
249
  def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
250
  CONFIDENCES = 0.80
251
  MAX_NEW_TOKENS = 50
 
 
252
 
253
  ocr_parts = ''
254
  if not do_use_trOCR:
 
683
 
684
 
685
  def process_image(self, do_create_OCR_helper_image, logger):
686
+ if 'hand' not in self.OCR_option and 'normal' not in self.OCR_option:
687
+ do_create_OCR_helper_image = False
688
+
689
  # Can stack options, so solitary if statements
690
  self.OCR = 'OCR:\n'
691
  if 'CRAFT' in self.OCR_option:
 
706
  image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
707
  self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
708
 
709
+ self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
 
 
 
 
710
 
711
  if self.double_OCR:
712
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
 
714
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
715
  # logger.info(f"LLaVA OCR:\n{self.OCR}")
716
 
717
+ if 'Florence-2' in self.OCR_option: # This option does not produce an OCR helper image
718
+ if self.json_report:
719
+ self.json_report.set_text(text_main=f'Working on Florence-2 [{self.Florence.model_id}] transcription :construction:')
720
+
721
+ self.logger.info(f"Florence-2 Usage Report for Model [{self.Florence.model_id}]")
722
+ results_text, results_text_dirty, results, usage_report = self.Florence.ocr_florence(self.path, task_prompt='<OCR>', text_input=None)
723
+
724
+ self.OCR_JSON_to_file['OCR_Florence'] = results_text
725
+
726
+ if self.double_OCR:
727
+ self.OCR = self.OCR + f"\nFlorence-2 OCR:\n{results_text}" + f"\nFlorence-2 OCR:\n{results_text}"
728
+ else:
729
+ self.OCR = self.OCR + f"\nFlorence-2 OCR:\n{results_text}"
730
+
731
  if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
732
  if 'normal' in self.OCR_option:
733
  if self.double_OCR:
 
781
 
782
  ### Merge final overlay image
783
  ### [original, normal bboxes, normal text]
784
+ if 'hand' in self.OCR_option or 'normal' in self.OCR_option:
785
+ if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
786
+ self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
787
+ ### [original, hand bboxes, hand text]
788
+ elif 'hand' in self.OCR_option:
789
+ self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
790
+ ### [original, normal bboxes, normal text, hand bboxes, hand text]
791
+ else:
792
+ self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
793
+
794
 
795
  if self.do_use_trOCR:
796
  if 'CRAFT' in self.OCR_option:
vouchervision/VoucherVision_Config_Builder.py CHANGED
@@ -36,21 +36,22 @@ def build_VV_config(loaded_cfg=None):
36
  save_cropped_annotations = ['label','barcode']
37
 
38
  do_use_trOCR = False
 
39
  trOCR_model_path = "microsoft/trocr-large-handwritten"
 
40
  OCR_option = 'hand'
41
  OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
42
  OCR_option_llava_bit = 'full' # full or 4bit
43
  double_OCR = False
44
 
45
-
46
  tool_GEO = True
47
  tool_WFO = True
48
  tool_wikipedia = True
49
 
50
  check_for_illegal_filenames = False
51
 
52
- LLM_version_user = 'Azure GPT 4' #'Azure GPT 4 Turbo 1106-preview'
53
- prompt_version = 'SLTPvA_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
54
  use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
55
  do_create_OCR_helper_image = True
56
 
@@ -71,7 +72,7 @@ def build_VV_config(loaded_cfg=None):
71
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
72
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
73
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
74
- prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
75
  OCR_option_llava_bit, double_OCR, save_cropped_annotations,
76
  tool_GEO, tool_WFO, tool_wikipedia,
77
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
@@ -88,7 +89,9 @@ def build_VV_config(loaded_cfg=None):
88
  catalog_numerical_only = loaded_cfg['leafmachine']['project']['catalog_numerical_only']
89
 
90
  do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
 
91
  trOCR_model_path = loaded_cfg['leafmachine']['project']['trOCR_model_path']
 
92
  OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
93
  OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
94
  OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
@@ -118,7 +121,7 @@ def build_VV_config(loaded_cfg=None):
118
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
119
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
120
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
121
- prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
122
  OCR_option_llava_bit, double_OCR, save_cropped_annotations,
123
  tool_GEO, tool_WFO, tool_wikipedia,
124
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
@@ -127,7 +130,7 @@ def build_VV_config(loaded_cfg=None):
127
  def assemble_config(dir_home, run_name, dir_images_local,dir_output,
128
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
129
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
130
- prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
131
  OCR_option_llava_bit, double_OCR, save_cropped_annotations,
132
  tool_GEO, tool_WFO, tool_wikipedia,
133
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
@@ -174,7 +177,9 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
174
  'delete_all_temps': False,
175
  'delete_temps_keep_VVE': False,
176
  'do_use_trOCR': do_use_trOCR,
 
177
  'trOCR_model_path': trOCR_model_path,
 
178
  'OCR_option': OCR_option,
179
  'OCR_option_llava': OCR_option_llava,
180
  'OCR_option_llava_bit': OCR_option_llava_bit,
 
36
  save_cropped_annotations = ['label','barcode']
37
 
38
  do_use_trOCR = False
39
+ do_use_florence = False
40
  trOCR_model_path = "microsoft/trocr-large-handwritten"
41
+ florence_model_path = "microsoft/Florence-2-large"
42
  OCR_option = 'hand'
43
  OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
44
  OCR_option_llava_bit = 'full' # full or 4bit
45
  double_OCR = False
46
 
 
47
  tool_GEO = True
48
  tool_WFO = True
49
  tool_wikipedia = True
50
 
51
  check_for_illegal_filenames = False
52
 
53
+ LLM_version_user = 'Gemini 1.5 Flash' # 'Azure GPT 4' #'Azure GPT 4 Turbo 1106-preview'
54
+ prompt_version = 'SLTPvM_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
55
  use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
56
  do_create_OCR_helper_image = True
57
 
 
72
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
73
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
74
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
75
+ prompt_version, do_create_OCR_helper_image, do_use_trOCR, do_use_florence, trOCR_model_path, florence_model_path, OCR_option, OCR_option_llava,
76
  OCR_option_llava_bit, double_OCR, save_cropped_annotations,
77
  tool_GEO, tool_WFO, tool_wikipedia,
78
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
 
89
  catalog_numerical_only = loaded_cfg['leafmachine']['project']['catalog_numerical_only']
90
 
91
  do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
92
+ do_use_florence = loaded_cfg['leafmachine']['project']['do_use_florence']
93
  trOCR_model_path = loaded_cfg['leafmachine']['project']['trOCR_model_path']
94
+ florence_model_path = loaded_cfg['leafmachine']['project']['florence_model_path']
95
  OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
96
  OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
97
  OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
 
121
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
122
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
123
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
124
+ prompt_version, do_create_OCR_helper_image, do_use_trOCR, do_use_florence, trOCR_model_path, florence_model_path, OCR_option, OCR_option_llava,
125
  OCR_option_llava_bit, double_OCR, save_cropped_annotations,
126
  tool_GEO, tool_WFO, tool_wikipedia,
127
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
 
130
  def assemble_config(dir_home, run_name, dir_images_local,dir_output,
131
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
132
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
133
+ prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, do_use_florence, trOCR_model_path, florence_model_path, OCR_option, OCR_option_llava,
134
  OCR_option_llava_bit, double_OCR, save_cropped_annotations,
135
  tool_GEO, tool_WFO, tool_wikipedia,
136
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
 
177
  'delete_all_temps': False,
178
  'delete_temps_keep_VVE': False,
179
  'do_use_trOCR': do_use_trOCR,
180
+ 'do_use_florence': do_use_florence,
181
  'trOCR_model_path': trOCR_model_path,
182
+ 'florence_model_path': florence_model_path,
183
  'OCR_option': OCR_option,
184
  'OCR_option_llava': OCR_option_llava,
185
  'OCR_option_llava_bit': OCR_option_llava_bit,
vouchervision/fetch_data.py CHANGED
@@ -7,7 +7,7 @@ import urllib.request
7
  from tqdm import tqdm
8
  import subprocess
9
 
10
- VERSION = 'v-2-1'
11
 
12
  def fetch_data(logger, dir_home, cfg_file_path):
13
  logger.name = 'Fetch Data'
 
7
  from tqdm import tqdm
8
  import subprocess
9
 
10
+ VERSION = 'v-2-3'
11
 
12
  def fetch_data(logger, dir_home, cfg_file_path):
13
  logger.name = 'Fetch Data'
vouchervision/generate_partner_collage.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import time
6
+
7
+ # Global variables
8
+ H = 200
9
+ ROWS = 6
10
+ PADDING = 30
11
+
12
+ # Step 1: Fetch the Images from the URL Folder
13
+ def fetch_image_urls(url):
14
+ response = requests.get(url + '?t=' + str(time.time()))
15
+ soup = BeautifulSoup(response.content, 'html.parser')
16
+ images = {}
17
+ for node in soup.find_all('a'):
18
+ href = node.get('href')
19
+ if href.endswith(('.png', '.jpg', '.jpeg')):
20
+ try:
21
+ image_index = int(href.split('__')[0])
22
+ images[image_index] = url + '/' + href + '?t=' + str(time.time())
23
+ except ValueError:
24
+ print(f"Skipping invalid image: {href}")
25
+ return images
26
+
27
+ # Step 2: Resize Images to Height H
28
+ def fetch_image(url):
29
+ response = requests.get(url)
30
+ return Image.open(BytesIO(response.content))
31
+
32
+ def resize_images(images, target_height):
33
+ resized_images = {}
34
+ for index, img in images.items():
35
+ ratio = target_height / img.height
36
+ new_width = int(img.width * ratio)
37
+ resized_img = img.resize((new_width, target_height), Image.BICUBIC)
38
+ resized_images[index] = resized_img
39
+ return resized_images
40
+
41
+ # Step 3: Create a Collage with Efficient Placement Algorithm
42
+ def create_collage(image_urls, collage_path, H, ROWS, PADDING):
43
+ images = {index: fetch_image(url) for index, url in image_urls.items()}
44
+ resized_images = resize_images(images, H) # Resize to H pixels height
45
+
46
+ center_image = resized_images.pop(0)
47
+ other_images = list(resized_images.items())
48
+
49
+ # Calculate collage size based on the number of rows
50
+ collage_width = 3000 # 16:9 aspect ratio width
51
+ collage_height = (H + PADDING) * ROWS + 2 * PADDING # Adjust height based on number of rows, add padding to top and bottom
52
+ collage = Image.new('RGB', (collage_width, collage_height), (255, 255, 255))
53
+
54
+ # Sort images by width and height
55
+ sorted_images = sorted(other_images, key=lambda x: x[1].width * x[1].height, reverse=True)
56
+
57
+ # Create alternate placement list and insert the center image in the middle
58
+ alternate_images = []
59
+ i, j = 0, len(sorted_images) - 1
60
+ halfway_point = (len(sorted_images) + 1) // 2
61
+ count = 0
62
+
63
+ while i <= j:
64
+ if count == halfway_point:
65
+ alternate_images.append((0, center_image))
66
+ if i == j:
67
+ alternate_images.append(sorted_images[i])
68
+ else:
69
+ alternate_images.append(sorted_images[i])
70
+ alternate_images.append(sorted_images[j])
71
+ i += 1
72
+ j -= 1
73
+ count += 2
74
+
75
+ # Calculate number of images per row
76
+ images_per_row = len(alternate_images) // ROWS
77
+ extra_images = len(alternate_images) % ROWS
78
+
79
+ # Place images in rows with only padding space between them
80
+ def place_images_in_rows(images, collage, max_width, padding, row_height, rows, images_per_row, extra_images):
81
+ y = padding
82
+ for current_row in range(rows):
83
+ row_images_count = images_per_row + (1 if extra_images > 0 else 0)
84
+ extra_images -= 1 if extra_images > 0 else 0
85
+ row_images = images[:row_images_count]
86
+ row_width = sum(img.width for idx, img in row_images) + padding * (row_images_count - 1)
87
+ x = (max_width - row_width) // 2
88
+ for idx, img in row_images:
89
+ collage.paste(img, (x, y))
90
+ x += img.width + padding
91
+ y += row_height + padding
92
+ images = images[row_images_count:]
93
+
94
+ place_images_in_rows(alternate_images, collage, collage_width, PADDING, H, ROWS, images_per_row, extra_images)
95
+
96
+ collage.save(collage_path)
97
+
98
+ # Define the URL folder and other constants
99
+ url_folder = 'https://leafmachine.org/partners/'
100
+ collage_path = 'img/collage.jpg'
101
+
102
+ # Fetch, Create, and Update
103
+ image_urls = fetch_image_urls(url_folder)
104
+ create_collage(image_urls, collage_path, H, ROWS, PADDING)
vouchervision/librarian_knowledge.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "catalogNumber": "barcode identifier, at least 6 digits, fewer than 30 digits.",
3
+ "order": "full scientific name of the Order in which the taxon is classified. Order must be capitalized.",
4
+ "family": "full scientific name of the Family in which the taxon is classified. Family must be capitalized.",
5
+ "scientificName": "scientific name of the taxon including Genus, specific epithet, and any lower classifications.",
6
+ "scientificNameAuthorship": "authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.",
7
+ "genus": "taxonomic determination to Genus, Genus must be capitalized.",
8
+ "specificEpithet": "The name of the first or species epithet of the scientificName. Only include the species epithet.",
9
+ "identifiedBy": "list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.",
10
+ "recordedBy": "list of names of people, doctors, professors, groups, or organizations.",
11
+ "recordNumber": "identifier given to the specimen at the time it was recorded.",
12
+ "verbatimEventDate": "The verbatim original representation of the date and time information for when the specimen was collected.",
13
+ "eventDate": "collection date formatted as year-month-day YYYY-MM-DD.",
14
+ "habitat": "habitat.",
15
+ "occurrenceRemarks": "all descriptive text in the OCR rearranged into sensible sentences or sentence fragments.",
16
+ "country": "country or major administrative unit.",
17
+ "stateProvince": "state, province, canton, department, region, etc.",
18
+ "county": "county, shire, department, parish etc.",
19
+ "municipality": "city, municipality, etc.",
20
+ "locality": "description of geographic information aiding in pinpointing the exact origin or location of the specimen.",
21
+ "degreeOfEstablishment": "cultivated plants are intentionally grown by humans. Use either - unknown or cultivated.",
22
+ "decimalLatitude": "latitude decimal coordinate.",
23
+ "decimalLongitude": "longitude decimal coordinate.",
24
+ "verbatimCoordinates": "verbatim location coordinates.",
25
+ "minimumElevationInMeters": "minimum elevation or altitude in meters.",
26
+ "maximumElevationInMeters": "maximum elevation or altitude in meters."
27
+ }
vouchervision/save_dataset.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+
3
+ # Load the dataset
4
+ dataset = load_dataset("phyloforfun/HLT_MICH_Angiospermae_SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05")
5
+
6
+ # Define the directory where you want to save the files
7
+ save_dir = "D:/Dropbox/VoucherVision/datasets/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05"
8
+
9
+ # Save each split as a JSONL file in the specified directory
10
+ for split, split_dataset in dataset.items():
11
+ split_dataset.to_json(f"{save_dir}/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05-{split}.jsonl")
12
+
13
+
14
+ '''import json # convert to google
15
+
16
+ # Load the JSONL file
17
+ input_file_path = '/mnt/data/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05-train.jsonl'
18
+ output_file_path = '/mnt/data/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05-train-converted.jsonl'
19
+
20
+ # Define the conversion function
21
+ def convert_record(record):
22
+ return {
23
+ "input_text": record.get('instruction', '') + ' ' + record.get('input', ''),
24
+ "target_text": record.get('output', '')
25
+ }
26
+
27
+ # Convert and save the new JSONL file
28
+ with open(input_file_path, 'r', encoding='utf-8') as infile, open(output_file_path, 'w', encoding='utf-8') as outfile:
29
+ for line in infile:
30
+ record = json.loads(line)
31
+ converted_record = convert_record(record)
32
+ outfile.write(json.dumps(converted_record) + '\n')
33
+
34
+ output_file_path'''
vouchervision/utils_VoucherVision.py CHANGED
@@ -14,6 +14,7 @@ from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
14
  from vouchervision.LLM_MistralAI import MistralHandler
15
  from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
16
  from vouchervision.LLM_local_MistralAI import LocalMistralHandler
 
17
  from vouchervision.prompt_catalog import PromptCatalog
18
  from vouchervision.model_maps import ModelMaps
19
  from vouchervision.general_utils import get_cfg_from_full_path
@@ -449,6 +450,8 @@ class VoucherVision():
449
  k_openai = os.getenv('OPENAI_API_KEY')
450
  k_openai_azure = os.getenv('AZURE_API_VERSION')
451
 
 
 
452
  k_google_project_id = os.getenv('GOOGLE_PROJECT_ID')
453
  k_google_location = os.getenv('GOOGLE_LOCATION')
454
  k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
@@ -464,6 +467,8 @@ class VoucherVision():
464
  k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
465
  k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
466
 
 
 
467
  k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
468
  k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
469
  k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
@@ -478,6 +483,8 @@ class VoucherVision():
478
  self.has_key_azure_openai = self.has_API_key(k_openai_azure)
479
  self.llm = None
480
 
 
 
481
  self.has_key_google_project_id = self.has_API_key(k_google_project_id)
482
  self.has_key_google_location = self.has_API_key(k_google_location)
483
  self.has_key_google_application_credentials = self.has_API_key(k_google_application_credentials)
@@ -505,6 +512,11 @@ class VoucherVision():
505
  openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
506
  os.environ["OPENAI_API_KEY"] = self.cfg_private['openai']['OPENAI_API_KEY']
507
 
 
 
 
 
 
508
 
509
  ### OpenAI - Azure
510
  if self.has_key_azure_openai:
@@ -738,6 +750,10 @@ class VoucherVision():
738
  response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_cpu_MistralAI(prompt, json_report, paths)
739
  else:
740
  response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_MistralAI(prompt, json_report, paths)
 
 
 
 
741
  else:
742
  response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_OpenAI(prompt, json_report, paths)
743
 
@@ -771,6 +787,8 @@ class VoucherVision():
771
  return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
772
  else:
773
  return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
 
 
774
  else:
775
  if 'PALM2' in name_parts:
776
  return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
 
14
  from vouchervision.LLM_MistralAI import MistralHandler
15
  from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
16
  from vouchervision.LLM_local_MistralAI import LocalMistralHandler
17
+ from vouchervision.LLM_local_custom_fine_tune import LocalFineTuneHandler
18
  from vouchervision.prompt_catalog import PromptCatalog
19
  from vouchervision.model_maps import ModelMaps
20
  from vouchervision.general_utils import get_cfg_from_full_path
 
450
  k_openai = os.getenv('OPENAI_API_KEY')
451
  k_openai_azure = os.getenv('AZURE_API_VERSION')
452
 
453
+ k_huggingface = None
454
+
455
  k_google_project_id = os.getenv('GOOGLE_PROJECT_ID')
456
  k_google_location = os.getenv('GOOGLE_LOCATION')
457
  k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
 
467
  k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
468
  k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
469
 
470
+ k_huggingface = self.cfg_private['huggingface']['hf_token']
471
+
472
  k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
473
  k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
474
  k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
 
483
  self.has_key_azure_openai = self.has_API_key(k_openai_azure)
484
  self.llm = None
485
 
486
+ self.has_key_huggingface = self.has_API_key(k_huggingface)
487
+
488
  self.has_key_google_project_id = self.has_API_key(k_google_project_id)
489
  self.has_key_google_location = self.has_API_key(k_google_location)
490
  self.has_key_google_application_credentials = self.has_API_key(k_google_application_credentials)
 
512
  openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
513
  os.environ["OPENAI_API_KEY"] = self.cfg_private['openai']['OPENAI_API_KEY']
514
 
515
+ if self.has_key_huggingface:
516
+ if self.is_hf:
517
+ pass
518
+ else:
519
+ os.environ["HUGGING_FACE_KEY"] = self.cfg_private['huggingface']['hf_token']
520
 
521
  ### OpenAI - Azure
522
  if self.has_key_azure_openai:
 
750
  response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_cpu_MistralAI(prompt, json_report, paths)
751
  else:
752
  response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_MistralAI(prompt, json_report, paths)
753
+
754
+ elif "/" in ''.join(name_parts):
755
+ response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_custom_fine_tune(self.OCR, json_report, paths) ###
756
+
757
  else:
758
  response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_OpenAI(prompt, json_report, paths)
759
 
 
787
  return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
788
  else:
789
  return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
790
+ elif "/" in ''.join(name_parts):
791
+ return LocalFineTuneHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
792
  else:
793
  if 'PALM2' in name_parts:
794
  return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
vouchervision/utils_hf (DESKTOP-548UDCR's conflicted copy 2024-06-15).py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, re, datetime, tempfile, yaml
2
+ from googleapiclient.discovery import build
3
+ from googleapiclient.http import MediaFileUpload
4
+ from google.oauth2 import service_account
5
+ import base64
6
+ from PIL import Image
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ from shutil import copyfileobj, copyfile
10
+
11
+ # from vouchervision.general_utils import get_cfg_from_full_path
12
+
13
+
14
+ def setup_streamlit_config(dir_home):
15
+ # Define the directory path and filename
16
+ dir_path = os.path.join(dir_home, ".streamlit")
17
+ file_path = os.path.join(dir_path, "config.toml")
18
+
19
+ # Check if directory exists, if not create it
20
+ if not os.path.exists(dir_path):
21
+ os.makedirs(dir_path)
22
+
23
+ # Create or modify the file with the provided content
24
+ config_content = f"""
25
+ [theme]
26
+ base = "dark"
27
+ primaryColor = "#00ff00"
28
+
29
+ [server]
30
+ enableStaticServing = false
31
+ runOnSave = true
32
+ port = 8524
33
+ maxUploadSize = 5000
34
+ """
35
+
36
+ with open(file_path, "w") as f:
37
+ f.write(config_content.strip())
38
+
39
+
40
+ def save_uploaded_file_local(directory_in, directory_out, img_file_name, image=None):
41
+ if not os.path.exists(directory_out):
42
+ os.makedirs(directory_out)
43
+
44
+ # Assuming img_file_name includes the extension
45
+ img_file_base, img_file_ext = os.path.splitext(img_file_name)
46
+
47
+ full_path_out = os.path.join(directory_out, img_file_name)
48
+ full_path_in = os.path.join(directory_in, img_file_name)
49
+
50
+ # Check if the file extension is .pdf (or add other conditions for different file types)
51
+ if img_file_ext.lower() == '.pdf':
52
+ # Copy the file from the input directory to the output directory
53
+ copyfile(full_path_in, full_path_out)
54
+ return full_path_out
55
+ else:
56
+ if image is None:
57
+ try:
58
+ with Image.open(full_path_in) as image:
59
+ image.save(full_path_out, "JPEG")
60
+ # Return the full path of the saved image
61
+ return full_path_out
62
+ except:
63
+ pass
64
+ else:
65
+ try:
66
+ image.save(full_path_out, "JPEG")
67
+ return full_path_out
68
+ except:
69
+ pass
70
+
71
+
72
+ def save_uploaded_file(directory, img_file, image=None):
73
+ if not os.path.exists(directory):
74
+ os.makedirs(directory)
75
+
76
+ full_path = os.path.join(directory, img_file.name)
77
+
78
+ # Assuming the uploaded file is an image
79
+ if img_file.name.lower().endswith('.pdf'):
80
+ with open(full_path, 'wb') as out_file:
81
+ # If img_file is a file-like object (e.g., Django's UploadedFile),
82
+ # you can use copyfileobj or read chunks.
83
+ # If it's a path, you'd need to open and then save it.
84
+ if hasattr(img_file, 'read'):
85
+ # This is a file-like object
86
+ copyfileobj(img_file, out_file)
87
+ else:
88
+ # If img_file is a path string
89
+ with open(img_file, 'rb') as fd:
90
+ copyfileobj(fd, out_file)
91
+ return full_path
92
+ else:
93
+ if image is None:
94
+ try:
95
+ with Image.open(img_file) as image:
96
+ full_path = os.path.join(directory, img_file.name)
97
+ image.save(full_path, "JPEG")
98
+ # Return the full path of the saved image
99
+ return full_path
100
+ except:
101
+ try:
102
+ with Image.open(os.path.join(directory,img_file)) as image:
103
+ full_path = os.path.join(directory, img_file)
104
+ image.save(full_path, "JPEG")
105
+ # Return the full path of the saved image
106
+ return full_path
107
+ except:
108
+ with Image.open(img_file.name) as image:
109
+ full_path = os.path.join(directory, img_file.name)
110
+ image.save(full_path, "JPEG")
111
+ # Return the full path of the saved image
112
+ return full_path
113
+ else:
114
+ try:
115
+ full_path = os.path.join(directory, img_file.name)
116
+ image.save(full_path, "JPEG")
117
+ return full_path
118
+ except:
119
+ full_path = os.path.join(directory, img_file)
120
+ image.save(full_path, "JPEG")
121
+ return full_path
122
+ # def save_uploaded_file(directory, uploaded_file, image=None):
123
+ # if not os.path.exists(directory):
124
+ # os.makedirs(directory)
125
+
126
+ # full_path = os.path.join(directory, uploaded_file.name)
127
+
128
+ # # Handle PDF files
129
+ # if uploaded_file.name.lower().endswith('.pdf'):
130
+ # with open(full_path, 'wb') as out_file:
131
+ # if hasattr(uploaded_file, 'read'):
132
+ # copyfileobj(uploaded_file, out_file)
133
+ # else:
134
+ # with open(uploaded_file, 'rb') as fd:
135
+ # copyfileobj(fd, out_file)
136
+ # return full_path
137
+ # else:
138
+ # if image is None:
139
+ # try:
140
+ # with Image.open(uploaded_file) as img:
141
+ # img.save(full_path, "JPEG")
142
+ # except:
143
+ # with Image.open(full_path) as img:
144
+ # img.save(full_path, "JPEG")
145
+ # else:
146
+ # try:
147
+ # image.save(full_path, "JPEG")
148
+ # except:
149
+ # image.save(os.path.join(directory, uploaded_file.name), "JPEG")
150
+ # return full_path
151
+
152
+ def save_uploaded_local(directory, img_file, image=None):
153
+ name = img_file.split(os.path.sep)[-1]
154
+ if not os.path.exists(directory):
155
+ os.makedirs(directory)
156
+
157
+ # Assuming the uploaded file is an image
158
+ if image is None:
159
+ with Image.open(img_file) as image:
160
+ full_path = os.path.join(directory, name)
161
+ image.save(full_path, "JPEG")
162
+ # Return the full path of the saved image
163
+ return os.path.join('uploads_small',name)
164
+ else:
165
+ full_path = os.path.join(directory, name)
166
+ image.save(full_path, "JPEG")
167
+ return os.path.join('.','uploads_small',name)
168
+
169
+ def image_to_base64(img):
170
+ buffered = BytesIO()
171
+ img.save(buffered, format="JPEG")
172
+ return base64.b64encode(buffered.getvalue()).decode()
173
+
174
+ def check_prompt_yaml_filename(fname):
175
+ # Check if the filename only contains letters, numbers, underscores, and dashes
176
+ pattern = r'^[\w-]+$'
177
+
178
+ # The \w matches any alphanumeric character and is equivalent to the character class [a-zA-Z0-9_].
179
+ # The hyphen - is literally matched.
180
+
181
+ if re.match(pattern, fname):
182
+ return True
183
+ else:
184
+ return False
185
+
186
+ def report_violation(file_name, is_hf=True):
187
+ # Format the current date and time
188
+ current_time = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
189
+ violation_file_name = f"violation_{current_time}.yaml" # Updated variable name to avoid confusion
190
+
191
+ # Create a temporary YAML file in text mode
192
+ with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.yaml') as temp_file:
193
+ # Example content - customize as needed
194
+ content = {
195
+ 'violation_time': current_time,
196
+ 'notes': 'This is an autogenerated violation report.',
197
+ 'name_of_file': file_name,
198
+ }
199
+ # Write the content to the temporary YAML file in text mode
200
+ yaml.dump(content, temp_file, default_flow_style=False)
201
+ temp_filepath = temp_file.name
202
+
203
+ # Now upload the temporary file
204
+ upload_to_drive(temp_filepath, violation_file_name, is_hf=is_hf)
205
+
206
+ # Optionally, delete the temporary file if you don't want it to remain on disk after uploading
207
+ os.remove(temp_filepath)
208
+
209
+ # Function to upload files to Google Drive
210
+ def upload_to_drive(filepath, filename, is_hf=True, cfg_private=None, do_upload = True):
211
+ if do_upload:
212
+ creds = get_google_credentials(is_hf=is_hf, cfg_private=cfg_private)
213
+ if creds:
214
+ service = build('drive', 'v3', credentials=creds)
215
+
216
+ # Get the folder ID from the environment variable
217
+ if is_hf:
218
+ folder_id = os.environ.get('GDRIVE_FOLDER_ID') # Renamed for clarity
219
+ else:
220
+ folder_id = cfg_private['google']['GDRIVE_FOLDER_ID'] # Renamed for clarity
221
+
222
+
223
+ if folder_id:
224
+ file_metadata = {
225
+ 'name': filename,
226
+ 'parents': [folder_id]
227
+ }
228
+
229
+ # Determine the mimetype based on the file extension
230
+ if filename.endswith('.yaml') or filename.endswith('.yml') or filepath.endswith('.yaml') or filepath.endswith('.yml'):
231
+ mimetype = 'application/x-yaml'
232
+ elif filepath.endswith('.zip'):
233
+ mimetype = 'application/zip'
234
+ else:
235
+ # Set a default mimetype if desired or handle the unsupported file type
236
+ print("Unsupported file type")
237
+ return None
238
+
239
+ # Upload the file
240
+ try:
241
+ media = MediaFileUpload(filepath, mimetype=mimetype)
242
+ file = service.files().create(
243
+ body=file_metadata,
244
+ media_body=media,
245
+ fields='id'
246
+ ).execute()
247
+ print(f"Uploaded file with ID: {file.get('id')}")
248
+ except Exception as e:
249
+ msg = f"If the following error is '404 cannot find file...' then you need to share the GDRIVE folder with your Google API service account's email address. Open your Google API JSON file, find the email account that ends with '@developer.gserviceaccount.com', go to your Google Drive, share the folder with this email account. {e}"
250
+ print(msg)
251
+ raise Exception(msg)
252
+ else:
253
+ print("GDRIVE_API environment variable not set.")
254
+
255
+ def get_google_credentials(is_hf=True, cfg_private=None): # Also used for google drive
256
+ if is_hf:
257
+ creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
258
+ credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
259
+ return credentials
260
+ else:
261
+ with open(cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'], 'r') as file:
262
+ data = json.load(file)
263
+ creds_json_str = json.dumps(data)
264
+ credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
265
+ os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = creds_json_str
266
+ return credentials
vouchervision/utils_hf.py CHANGED
@@ -73,7 +73,7 @@ def save_uploaded_file(directory, img_file, image=None):
73
  if not os.path.exists(directory):
74
  os.makedirs(directory)
75
 
76
- full_path = os.path.join(directory, img_file.name)
77
 
78
  # Assuming the uploaded file is an image
79
  if img_file.name.lower().endswith('.pdf'):
@@ -98,11 +98,18 @@ def save_uploaded_file(directory, img_file, image=None):
98
  # Return the full path of the saved image
99
  return full_path
100
  except:
101
- with Image.open(os.path.join(directory,img_file)) as image:
102
- full_path = os.path.join(directory, img_file)
103
- image.save(full_path, "JPEG")
104
- # Return the full path of the saved image
105
- return full_path
 
 
 
 
 
 
 
106
  else:
107
  try:
108
  full_path = os.path.join(directory, img_file.name)
@@ -112,6 +119,35 @@ def save_uploaded_file(directory, img_file, image=None):
112
  full_path = os.path.join(directory, img_file)
113
  image.save(full_path, "JPEG")
114
  return full_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def save_uploaded_local(directory, img_file, image=None):
117
  name = img_file.split(os.path.sep)[-1]
 
73
  if not os.path.exists(directory):
74
  os.makedirs(directory)
75
 
76
+ full_path = os.path.join(directory, img_file.name) ########## TODO THIS MUST BE MOVED TO conditional specific location
77
 
78
  # Assuming the uploaded file is an image
79
  if img_file.name.lower().endswith('.pdf'):
 
98
  # Return the full path of the saved image
99
  return full_path
100
  except:
101
+ try:
102
+ with Image.open(os.path.join(directory,img_file)) as image:
103
+ full_path = os.path.join(directory, img_file)
104
+ image.save(full_path, "JPEG")
105
+ # Return the full path of the saved image
106
+ return full_path
107
+ except:
108
+ with Image.open(img_file.name) as image:
109
+ full_path = os.path.join(directory, img_file.name)
110
+ image.save(full_path, "JPEG")
111
+ # Return the full path of the saved image
112
+ return full_path
113
  else:
114
  try:
115
  full_path = os.path.join(directory, img_file.name)
 
119
  full_path = os.path.join(directory, img_file)
120
  image.save(full_path, "JPEG")
121
  return full_path
122
+ # def save_uploaded_file(directory, uploaded_file, image=None):
123
+ # if not os.path.exists(directory):
124
+ # os.makedirs(directory)
125
+
126
+ # full_path = os.path.join(directory, uploaded_file.name)
127
+
128
+ # # Handle PDF files
129
+ # if uploaded_file.name.lower().endswith('.pdf'):
130
+ # with open(full_path, 'wb') as out_file:
131
+ # if hasattr(uploaded_file, 'read'):
132
+ # copyfileobj(uploaded_file, out_file)
133
+ # else:
134
+ # with open(uploaded_file, 'rb') as fd:
135
+ # copyfileobj(fd, out_file)
136
+ # return full_path
137
+ # else:
138
+ # if image is None:
139
+ # try:
140
+ # with Image.open(uploaded_file) as img:
141
+ # img.save(full_path, "JPEG")
142
+ # except:
143
+ # with Image.open(full_path) as img:
144
+ # img.save(full_path, "JPEG")
145
+ # else:
146
+ # try:
147
+ # image.save(full_path, "JPEG")
148
+ # except:
149
+ # image.save(os.path.join(directory, uploaded_file.name), "JPEG")
150
+ # return full_path
151
 
152
  def save_uploaded_local(directory, img_file, image=None):
153
  name = img_file.split(os.path.sep)[-1]