seanpedrickcase commited on
Commit
1cfa6e8
1 Parent(s): 6622361

Moved gradio run code to outside of lambda_handler function in lambda_entrypoint.py

Browse files
Files changed (1) hide show
  1. lambda_entrypoint.py +78 -80
lambda_entrypoint.py CHANGED
@@ -15,6 +15,16 @@ TMP_DIR = "/tmp/"
15
 
16
  run_direct_mode = os.getenv("RUN_DIRECT_MODE", "0")
17
 
 
 
 
 
 
 
 
 
 
 
18
  def download_file_from_s3(bucket_name, key, download_path):
19
  """Download a file from S3 to the local filesystem."""
20
  s3_client.download_file(bucket_name, key, download_path)
@@ -28,85 +38,73 @@ def upload_file_to_s3(file_path, bucket_name, key):
28
  def lambda_handler(event, context):
29
 
30
  print("In lambda_handler function")
31
-
32
- if run_direct_mode == "0":
33
- # Gradio App execution
34
- from app import app, max_queue_size, max_file_size # Replace with actual import if needed
35
- from tools.auth import authenticate_user
36
-
37
- if os.getenv("COGNITO_AUTH", "0") == "1":
38
- app.queue(max_size=max_queue_size).launch(show_error=True, auth=authenticate_user, max_file_size=max_file_size)
39
- else:
40
- app.queue(max_size=max_queue_size).launch(show_error=True, inbrowser=True, max_file_size=max_file_size)
41
-
42
- else:
43
-
44
- # Create necessary directories
45
- os.makedirs(os.path.join(TMP_DIR, "input"), exist_ok=True)
46
- os.makedirs(os.path.join(TMP_DIR, "output"), exist_ok=True)
47
-
48
- print("Got to record loop")
49
- print("Event records is:", event["Records"])
50
-
51
- # Extract S3 bucket and object key from the Records
52
- for record in event.get("Records", [{}]):
53
- bucket_name = record.get("s3", {}).get("bucket", {}).get("name")
54
- input_key = record.get("s3", {}).get("object", {}).get("key")
55
- print(f"Processing file {input_key} from bucket {bucket_name}")
56
-
57
- # Extract additional arguments
58
- arguments = event.get("arguments", {})
59
-
60
- if not input_key:
61
- input_key = arguments.get("input_file", "")
62
-
63
- ocr_method = arguments.get("ocr_method", "Complex image analysis - docs with handwriting/signatures (AWS Textract)")
64
- pii_detector = arguments.get("pii_detector", "AWS Comprehend")
65
- page_min = str(arguments.get("page_min", 0))
66
- page_max = str(arguments.get("page_max", 0))
67
- allow_list = arguments.get("allow_list", None)
68
- output_dir = arguments.get("output_dir", os.path.join(TMP_DIR, "output"))
69
-
70
- print(f"OCR Method: {ocr_method}")
71
- print(f"PII Detector: {pii_detector}")
72
- print(f"Page Range: {page_min} - {page_max}")
73
- print(f"Allow List: {allow_list}")
74
- print(f"Output Directory: {output_dir}")
75
-
76
- # Download input file
77
- input_file_path = os.path.join(TMP_DIR, "input", os.path.basename(input_key))
78
- download_file_from_s3(bucket_name, input_key, input_file_path)
79
-
80
- # Construct command
81
- command = [
82
- "python",
83
- "app.py",
84
- "--input_file", input_file_path,
85
- "--ocr_method", ocr_method,
86
- "--pii_detector", pii_detector,
87
- "--page_min", page_min,
88
- "--page_max", page_max,
89
- "--output_dir", output_dir,
90
- ]
91
-
92
- # Add allow_list only if provided
93
- if allow_list:
94
- allow_list_path = os.path.join(TMP_DIR, "allow_list.csv")
95
- download_file_from_s3(bucket_name, allow_list, allow_list_path)
96
- command.extend(["--allow_list", allow_list_path])
97
-
98
- try:
99
- result = subprocess.run(command, capture_output=True, text=True, check=True)
100
- print("Processing succeeded:", result.stdout)
101
- except subprocess.CalledProcessError as e:
102
- print("Error during processing:", e.stderr)
103
- raise e
104
-
105
- # Upload output files back to S3
106
- for root, _, files in os.walk(output_dir):
107
- for file_name in files:
108
- local_file_path = os.path.join(root, file_name)
109
- output_key = f"{os.path.dirname(input_key)}/output/{file_name}"
110
- upload_file_to_s3(local_file_path, bucket_name, output_key)
111
 
112
  return {"statusCode": 200, "body": "Processing complete."}
 
15
 
16
  run_direct_mode = os.getenv("RUN_DIRECT_MODE", "0")
17
 
18
+ if run_direct_mode == "0":
19
+ # Gradio App execution
20
+ from app import app, max_queue_size, max_file_size # Replace with actual import if needed
21
+ from tools.auth import authenticate_user
22
+
23
+ if os.getenv("COGNITO_AUTH", "0") == "1":
24
+ app.queue(max_size=max_queue_size).launch(show_error=True, auth=authenticate_user, max_file_size=max_file_size)
25
+ else:
26
+ app.queue(max_size=max_queue_size).launch(show_error=True, inbrowser=True, max_file_size=max_file_size)
27
+
28
  def download_file_from_s3(bucket_name, key, download_path):
29
  """Download a file from S3 to the local filesystem."""
30
  s3_client.download_file(bucket_name, key, download_path)
 
38
  def lambda_handler(event, context):
39
 
40
  print("In lambda_handler function")
41
+
42
+ # Create necessary directories
43
+ os.makedirs(os.path.join(TMP_DIR, "input"), exist_ok=True)
44
+ os.makedirs(os.path.join(TMP_DIR, "output"), exist_ok=True)
45
+
46
+ print("Got to record loop")
47
+ print("Event records is:", event["Records"])
48
+
49
+ # Extract S3 bucket and object key from the Records
50
+ for record in event.get("Records", [{}]):
51
+ bucket_name = record.get("s3", {}).get("bucket", {}).get("name")
52
+ input_key = record.get("s3", {}).get("object", {}).get("key")
53
+ print(f"Processing file {input_key} from bucket {bucket_name}")
54
+
55
+ # Extract additional arguments
56
+ arguments = event.get("arguments", {})
57
+
58
+ if not input_key:
59
+ input_key = arguments.get("input_file", "")
60
+
61
+ ocr_method = arguments.get("ocr_method", "Complex image analysis - docs with handwriting/signatures (AWS Textract)")
62
+ pii_detector = arguments.get("pii_detector", "AWS Comprehend")
63
+ page_min = str(arguments.get("page_min", 0))
64
+ page_max = str(arguments.get("page_max", 0))
65
+ allow_list = arguments.get("allow_list", None)
66
+ output_dir = arguments.get("output_dir", os.path.join(TMP_DIR, "output"))
67
+
68
+ print(f"OCR Method: {ocr_method}")
69
+ print(f"PII Detector: {pii_detector}")
70
+ print(f"Page Range: {page_min} - {page_max}")
71
+ print(f"Allow List: {allow_list}")
72
+ print(f"Output Directory: {output_dir}")
73
+
74
+ # Download input file
75
+ input_file_path = os.path.join(TMP_DIR, "input", os.path.basename(input_key))
76
+ download_file_from_s3(bucket_name, input_key, input_file_path)
77
+
78
+ # Construct command
79
+ command = [
80
+ "python",
81
+ "app.py",
82
+ "--input_file", input_file_path,
83
+ "--ocr_method", ocr_method,
84
+ "--pii_detector", pii_detector,
85
+ "--page_min", page_min,
86
+ "--page_max", page_max,
87
+ "--output_dir", output_dir,
88
+ ]
89
+
90
+ # Add allow_list only if provided
91
+ if allow_list:
92
+ allow_list_path = os.path.join(TMP_DIR, "allow_list.csv")
93
+ download_file_from_s3(bucket_name, allow_list, allow_list_path)
94
+ command.extend(["--allow_list", allow_list_path])
95
+
96
+ try:
97
+ result = subprocess.run(command, capture_output=True, text=True, check=True)
98
+ print("Processing succeeded:", result.stdout)
99
+ except subprocess.CalledProcessError as e:
100
+ print("Error during processing:", e.stderr)
101
+ raise e
102
+
103
+ # Upload output files back to S3
104
+ for root, _, files in os.walk(output_dir):
105
+ for file_name in files:
106
+ local_file_path = os.path.join(root, file_name)
107
+ output_key = f"{os.path.dirname(input_key)}/output/{file_name}"
108
+ upload_file_to_s3(local_file_path, bucket_name, output_key)
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  return {"statusCode": 200, "body": "Processing complete."}