# new-blip / handler.py
from PIL import Image
from typing import Dict, Any
import torch
import base64
from io import BytesIO
from transformers import BlipForConditionalGeneration, BlipProcessor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler():
    def __init__(self, path=""):
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()
        self.max_length = 16
        self.num_beams = 4

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            image_data = data.get("inputs", None)
            # Convert the base64-encoded image string to bytes
            image_bytes = base64.b64decode(image_data)
            # Open the image from an in-memory buffer and force RGB mode;
            # the BLIP processor expects a PIL image, not a raw byte buffer
            raw_image = Image.open(BytesIO(image_bytes)).convert("RGB")
            # Preprocess the image and move the tensors to the target device
            processed_inputs = self.processor(raw_image, return_tensors="pt").to(device)
            # Generate the caption with beam search
            gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
            with torch.no_grad():
                output_ids = self.model.generate(**processed_inputs, **gen_kwargs)
            caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            return {"caption": caption}
        except Exception as e:
            # Log the error for easier debugging on the endpoint
            print(f"Error during processing: {str(e)}")
            return {"caption": "", "error": str(e)}