from PIL import Image
from typing import Dict, Any
import torch
import base64
from io import BytesIO
from transformers import BlipForConditionalGeneration, BlipProcessor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    def __init__(self, path=""):
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()
        self.max_length = 16
        self.num_beams = 4

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            image_data = data.get("inputs", None)
            if image_data is None:
                return {"caption": "", "error": "No 'inputs' field in request payload"}

            # Decode the base64-encoded image string into raw bytes
            image_bytes = base64.b64decode(image_data)

            # The processor expects a PIL Image, not a raw byte buffer,
            # so open the image from an in-memory buffer first
            raw_image = Image.open(BytesIO(image_bytes))

            # BLIP expects 3-channel RGB input
            if raw_image.mode != "RGB":
                raw_image = raw_image.convert(mode="RGB")

            # Preprocess the image and move the tensors to the target device
            processed_inputs = self.processor(raw_image, return_tensors="pt").to(device)

            # Generate the caption with beam search; no gradients are needed at inference
            gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
            with torch.no_grad():
                output_ids = self.model.generate(**processed_inputs, **gen_kwargs)

            caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            return {"caption": caption}
        except Exception as e:
            # Log the error so failures show up in the endpoint logs
            print(f"Error during processing: {str(e)}")
            return {"caption": "", "error": str(e)}
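
# A minimal local smoke test for the handler above (a sketch; "cat.jpg" is a
# placeholder path, not part of the repository). Guarded so it only runs when
# this file is executed directly, never when the endpoint imports the handler.
if __name__ == "__main__":
    with open("cat.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

    handler = EndpointHandler()
    print(handler(payload))  # e.g. {"caption": "..."}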




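# For reference, a hypothetical client-side call once this handler is deployed
# as a Hugging Face Inference Endpoint. ENDPOINT_URL and HF_TOKEN below are
# placeholders, not real values; run this from a separate script, not from
# the handler itself.
#
#     import base64
#     import requests
#
#     ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
#     HF_TOKEN = "<your-access-token>"
#
#     with open("cat.jpg", "rb") as f:
#         encoded = base64.b64encode(f.read()).decode("utf-8")
#
#     response = requests.post(
#         ENDPOINT_URL,
#         headers={
#             "Authorization": f"Bearer {HF_TOKEN}",
#             "Content-Type": "application/json",
#         },
#         json={"inputs": encoded},
#     )
#     print(response.json())  # {"caption": "..."} on success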