File size: 2,688 Bytes
67f4974
 
 
 
375e6f0
 
 
67f4974
375e6f0
 
 
 
 
 
 
 
67f4974
375e6f0
 
 
 
 
 
 
 
 
67f4974
375e6f0
 
 
 
 
67f4974
375e6f0
 
 
 
 
 
 
 
 
 
 
 
 
67f4974
 
 
 
 
 
 
 
 
 
 
0f82bee
375e6f0
6a10656
67f4974
6a10656
67f4974
 
375e6f0
 
67f4974
 
bc42f94
 
 
375e6f0
67f4974
bc42f94
375e6f0
67f4974
 
 
375e6f0
 
67f4974
375e6f0
67f4974
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
AWS Lambda function
"""

import base64
import json
import logging
from detection import ml_detection, ml_utils


# getLogger with no name returns the root logger; only records at INFO or above are emitted.
logger = logging.getLogger(None)
logger.setLevel("INFO")


# Find ML model type based on string request
def get_model_type(query_string):
    """Find ML model type based on string request"""
    # Default ml model type
    if query_string == "":
        model_type = "facebook/detr-resnet-50"
    # Assess query string value
    elif "detr" in query_string:
        model_type = "facebook/" + query_string
    elif "yolos" in query_string:
        model_type = "hustvl/" + query_string
    else:
        raise Exception("Incorrect model type.")
    return model_type


# Run detection pipeline: load ML model, perform object detection and return json object
def detection_pipeline(model_type, image_bytes):
    """detection pipeline: load ML model, perform object detection and return json object"""
    # Load correct ML model
    processor, model = ml_detection.load_model(model_type)

    # Perform object detection
    results = ml_detection.object_detection(processor, model, image_bytes)

    # Convert dictionary of tensors to JSON object
    result_json_dict = ml_utils.convert_tensor_dict_to_json(results)

    return result_json_dict


def lambda_handler(event, context):
    """
        Lambda handler (proxy integration option unchecked on AWS API Gateway)

        Args:
            event (dict): The event that triggered the Lambda function.
            context (LambdaContext): Information about the execution environment.

        Returns:
            dict: The response to be returned from the Lambda function.
        """

    # logger.info(f"API event: {event}")
    try:
        # Retrieve model type
        model_query = event.get("model", "")
        model_type = get_model_type(model_query)
        logger.info("Model query: %s", model_query)
        logger.info("Model type: %s", model_type)

        # Decode the base64-encoded image data from the event
        image_data = event["body"]
        if event["isBase64Encoded"]:
            image_data = base64.b64decode(image_data)

        # Run detection pipeline
        result_dict = detection_pipeline(model_type, image_data)
        logger.info("API Results: %s", str(result_dict))

        return {
            "statusCode": 200,
            "headers": {"Content-Type": "application/json"},
            "body": json.dumps(result_dict),
        }
    except Exception as e:
        logger.info("API Error: %s", str(e))
        return {
            "statusCode": 500,
            "headers": {"Content-Type": "application/json"},
            "body": json.dumps({"error": str(e)}),
        }