Spaces:
Running
Running
File size: 2,688 Bytes
67f4974 375e6f0 67f4974 375e6f0 67f4974 375e6f0 67f4974 375e6f0 67f4974 375e6f0 67f4974 0f82bee 375e6f0 6a10656 67f4974 6a10656 67f4974 375e6f0 67f4974 bc42f94 375e6f0 67f4974 bc42f94 375e6f0 67f4974 375e6f0 67f4974 375e6f0 67f4974 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
"""
AWS Lambda function serving object detection (DETR / YOLOS models) behind AWS API Gateway
"""
import base64
import json
import logging
from detection import ml_detection, ml_utils
# Configure the root logger (getLogger with no name) rather than a
# module-specific one; loggers that do not set their own level inherit
# this INFO threshold through the root.
logger = logging.getLogger(None)
logger.setLevel("INFO")
# Find ML model type based on string request
def get_model_type(query_string):
"""Find ML model type based on string request"""
# Default ml model type
if query_string == "":
model_type = "facebook/detr-resnet-50"
# Assess query string value
elif "detr" in query_string:
model_type = "facebook/" + query_string
elif "yolos" in query_string:
model_type = "hustvl/" + query_string
else:
raise Exception("Incorrect model type.")
return model_type
# Run detection pipeline: load ML model, perform object detection and return json object
def detection_pipeline(model_type, image_bytes):
"""detection pipeline: load ML model, perform object detection and return json object"""
# Load correct ML model
processor, model = ml_detection.load_model(model_type)
# Perform object detection
results = ml_detection.object_detection(processor, model, image_bytes)
# Convert dictionary of tensors to JSON object
result_json_dict = ml_utils.convert_tensor_dict_to_json(results)
return result_json_dict
def lambda_handler(event, context):
"""
Lambda handler (proxy integration option unchecked on AWS API Gateway)
Args:
event (dict): The event that triggered the Lambda function.
context (LambdaContext): Information about the execution environment.
Returns:
dict: The response to be returned from the Lambda function.
"""
# logger.info(f"API event: {event}")
try:
# Retrieve model type
model_query = event.get("model", "")
model_type = get_model_type(model_query)
logger.info("Model query: %s", model_query)
logger.info("Model type: %s", model_type)
# Decode the base64-encoded image data from the event
image_data = event["body"]
if event["isBase64Encoded"]:
image_data = base64.b64decode(image_data)
# Run detection pipeline
result_dict = detection_pipeline(model_type, image_data)
logger.info("API Results: %s", str(result_dict))
return {
"statusCode": 200,
"headers": {"Content-Type": "application/json"},
"body": json.dumps(result_dict),
}
except Exception as e:
logger.info("API Error: %s", str(e))
return {
"statusCode": 500,
"headers": {"Content-Type": "application/json"},
"body": json.dumps({"error": str(e)}),
}
|