pdich2085 committed on
Commit 1721131
1 Parent(s): 27a8481

Update handler.py

Files changed (1)
  1. handler.py +60 -13
handler.py CHANGED
@@ -23,25 +23,18 @@ class EndpointHandler():
 
             # Convert base64 encoded image string to bytes
             image_bytes = base64.b64decode(image_data)
-
-            # Create a BytesIO object from the bytes data
-            image_buffer = BytesIO(image_bytes)
 
-            # Open the image from the buffer
-            raw_image = Image.open(image_buffer)
-
-            # Ensure the image is in RGB mode (if necessary)
-            if raw_image.mode != "RGB":
-                raw_image = raw_image.convert(mode="RGB")
+            # Convert bytes to a BytesIO object
+            image_buffer = BytesIO(image_bytes)
 
-            # Extract pixel values and move them to the device
-            pixel_values = self.processor(raw_image, return_tensors="pt").pixel_values.to(device)
+            # Process the image with the processor
+            processed_inputs = self.processor(image_buffer, return_tensors="pt").to(device)
 
             # Generate the caption
             gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
-            output_ids = self.model.generate(pixel_values, **gen_kwargs)
+            output_ids = self.model.generate(**processed_inputs, **gen_kwargs)
 
-            caption = self.processor.batch_decode(output_ids[0], skip_special_tokens=True).strip()
+            caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
 
             return {"caption": caption}
         except Exception as e:
@@ -52,6 +45,60 @@
 
 
 
+# from PIL import Image
+# from typing import Dict, Any
+# import torch
+# import base64
+# from io import BytesIO
+# from transformers import BlipForConditionalGeneration, BlipProcessor
+
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# class EndpointHandler():
+#     def __init__(self, path=""):
+#         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+#         self.model = BlipForConditionalGeneration.from_pretrained(
+#             "Salesforce/blip-image-captioning-large"
+#         ).to(device)
+#         self.model.eval()
+#         self.max_length = 16
+#         self.num_beams = 4
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         try:
+#             image_data = data.get("inputs", None)
+
+#             # Convert base64 encoded image string to bytes
+#             image_bytes = base64.b64decode(image_data)
+
+#             # Create a BytesIO object from the bytes data
+#             image_buffer = BytesIO(image_bytes)
+
+#             # Open the image from the buffer
+#             raw_image = Image.open(image_buffer)
+
+#             # Ensure the image is in RGB mode (if necessary)
+#             if raw_image.mode != "RGB":
+#                 raw_image = raw_image.convert(mode="RGB")
+
+#             # Extract pixel values and move them to the device
+#             pixel_values = self.processor(raw_image, return_tensors="pt").pixel_values.to(device)
+
+#             # Generate the caption
+#             gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
+#             output_ids = self.model.generate(pixel_values, **gen_kwargs)
+
+#             caption = self.processor.batch_decode(output_ids[0], skip_special_tokens=True).strip()
+
+#             return {"caption": caption}
+#         except Exception as e:
+#             # Log the error for better tracking
+#             print(f"Error during processing: {str(e)}")
+#             return {"caption": "", "error": str(e)}
+
+
+
+
 
 # from PIL import Image
 # from typing import Dict, Any
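
For local testing, a minimal smoke test along these lines should exercise the handler's request contract, which takes a payload of the form {"inputs": "<base64 image string>"} per the `data.get("inputs", None)` call in `__call__`. This is a sketch, not part of the commit: the image path `test.jpg` is hypothetical, and it assumes `handler.py` is importable from the working directory.

    # Hypothetical local smoke test for the handler above.
    import base64

    from handler import EndpointHandler

    # __init__ accepts a `path` argument, defaulting to "".
    handler = EndpointHandler(path="")

    # Encode an image as base64, matching the expected payload shape.
    # "test.jpg" is a placeholder; substitute any local image file.
    with open("test.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

    result = handler(payload)
    # Prints {"caption": "..."} on success, or
    # {"caption": "", "error": "..."} if the except branch fires.
    print(result)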