phyloforfun commited on
Commit
b2935d3
1 Parent(s): 5713f25

fix safety check

Browse files
Files changed (2) hide show
  1. app.py +2 -3
  2. vouchervision/OCR_google_cloud_vision.py +9 -58
app.py CHANGED
@@ -262,12 +262,11 @@ def handle_image_upload_and_gallery_hf(uploaded_files):
262
 
263
  ind_small = 0
264
  for uploaded_file in uploaded_files:
265
-
266
  if SAFE.check_for_inappropriate_content(uploaded_file):
267
  clear_image_uploads()
268
  report_violation(uploaded_file.name, is_hf=st.session_state['is_hf'])
269
  st.error("Warning: You uploaded an image that violates our terms of service.")
270
- return True
271
 
272
 
273
  # Determine the file type
@@ -391,7 +390,7 @@ def content_input_images(col_left, col_right):
391
 
392
  with col_right:
393
  if st.session_state.is_hf:
394
- result = handle_image_upload_and_gallery_hf(uploaded_files)
395
 
396
  else:
397
  st.session_state['view_local_gallery'] = st.toggle("View Image Gallery",)
 
262
 
263
  ind_small = 0
264
  for uploaded_file in uploaded_files:
265
+ uploaded_file.seek(0)
266
  if SAFE.check_for_inappropriate_content(uploaded_file):
267
  clear_image_uploads()
268
  report_violation(uploaded_file.name, is_hf=st.session_state['is_hf'])
269
  st.error("Warning: You uploaded an image that violates our terms of service.")
 
270
 
271
 
272
  # Determine the file type
 
390
 
391
  with col_right:
392
  if st.session_state.is_hf:
393
+ handle_image_upload_and_gallery_hf(uploaded_files)
394
 
395
  else:
396
  st.session_state['view_local_gallery'] = st.toggle("View Image Gallery",)
vouchervision/OCR_google_cloud_vision.py CHANGED
@@ -792,58 +792,6 @@ class OCREngine:
792
  except:
793
  pass
794
 
795
- # class SafetyCheck():
796
- # def __init__(self, is_hf) -> None:
797
- # self.is_hf = is_hf
798
- # self.set_client()
799
-
800
- # def set_client(self):
801
- # if self.is_hf:
802
- # self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
803
- # else:
804
- # self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
805
-
806
-
807
- # def get_google_credentials(self):
808
- # creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
809
- # credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
810
- # return credentials
811
-
812
- # def check_for_inappropriate_content(self, file_stream):
813
- # LEVEL = 2
814
- # content = file_stream.read()
815
- # image = vision.Image(content=content)
816
- # response = self.client.safe_search_detection(image=image)
817
- # safe = response.safe_search_annotation
818
-
819
- # likelihood_name = (
820
- # "UNKNOWN",
821
- # "VERY_UNLIKELY",
822
- # "UNLIKELY",
823
- # "POSSIBLE",
824
- # "LIKELY",
825
- # "VERY_LIKELY",
826
- # )
827
- # print("Safe search:")
828
-
829
- # print(f" adult*: {likelihood_name[safe.adult]}")
830
- # print(f" medical*: {likelihood_name[safe.medical]}")
831
- # print(f" spoofed: {likelihood_name[safe.spoof]}")
832
- # print(f" violence*: {likelihood_name[safe.violence]}")
833
- # print(f" racy: {likelihood_name[safe.racy]}")
834
-
835
- # # Check the levels of adult, violence, racy, etc. content.
836
- # if (safe.adult > LEVEL or
837
- # safe.medical > LEVEL or
838
- # # safe.spoof > LEVEL or
839
- # safe.violence > LEVEL #or
840
- # # safe.racy > LEVEL
841
- # ):
842
- # print("Found violation")
843
- # return True # The image violates safe search guidelines.
844
-
845
- # print("Found NO violation")
846
- # return False # The image is considered safe.
847
  class SafetyCheck():
848
  def __init__(self, is_hf) -> None:
849
  self.is_hf = is_hf
@@ -859,15 +807,14 @@ class SafetyCheck():
859
  creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
860
  credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
861
  return credentials
862
-
863
  def check_for_inappropriate_content(self, file_stream):
864
  LEVEL = 2
 
 
865
  content = file_stream.read()
866
  image = vision.Image(content=content)
867
- feature = vision.Feature(type=vision.Feature.Type.SAFE_SEARCH_DETECTION)
868
-
869
- request = vision.AnnotateImageRequest(image=image, features=[feature])
870
- response = self.client.annotate_image(request=request)
871
  safe = response.safe_search_annotation
872
 
873
  likelihood_name = (
@@ -879,6 +826,7 @@ class SafetyCheck():
879
  "VERY_LIKELY",
880
  )
881
  print("Safe search:")
 
882
  print(f" adult*: {likelihood_name[safe.adult]}")
883
  print(f" medical*: {likelihood_name[safe.medical]}")
884
  print(f" spoofed: {likelihood_name[safe.spoof]}")
@@ -888,7 +836,10 @@ class SafetyCheck():
888
  # Check the levels of adult, violence, racy, etc. content.
889
  if (safe.adult > LEVEL or
890
  safe.medical > LEVEL or
891
- safe.violence > LEVEL):
 
 
 
892
  print("Found violation")
893
  return True # The image violates safe search guidelines.
894
 
 
792
  except:
793
  pass
794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795
  class SafetyCheck():
796
  def __init__(self, is_hf) -> None:
797
  self.is_hf = is_hf
 
807
  creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
808
  credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
809
  return credentials
810
+
811
  def check_for_inappropriate_content(self, file_stream):
812
  LEVEL = 2
813
+ # content = file_stream.read()
814
+ file_stream.seek(0) # Reset file stream position to the beginning
815
  content = file_stream.read()
816
  image = vision.Image(content=content)
817
+ response = self.client.safe_search_detection(image=image)
 
 
 
818
  safe = response.safe_search_annotation
819
 
820
  likelihood_name = (
 
826
  "VERY_LIKELY",
827
  )
828
  print("Safe search:")
829
+
830
  print(f" adult*: {likelihood_name[safe.adult]}")
831
  print(f" medical*: {likelihood_name[safe.medical]}")
832
  print(f" spoofed: {likelihood_name[safe.spoof]}")
 
836
  # Check the levels of adult, violence, racy, etc. content.
837
  if (safe.adult > LEVEL or
838
  safe.medical > LEVEL or
839
+ # safe.spoof > LEVEL or
840
+ safe.violence > LEVEL #or
841
+ # safe.racy > LEVEL
842
+ ):
843
  print("Found violation")
844
  return True # The image violates safe search guidelines.
845