"""Inference-endpoint handler that captions / answers prompts about an image."""

import base64
import io
import json
from typing import Any, Dict

import requests
import torch
from PIL import Image

# NOTE(review): these Qwen* class names do not appear to exist in upstream
# transformers (Qwen2-VL ships e.g. Qwen2VLForConditionalGeneration /
# AutoProcessor) — confirm they are provided by the pinned version.
from transformers import (
    QwenForMultiModalConditionalGeneration,
    QwenImageProcessor,
    QwenTokenizer,
)


class EndpointHandler:
    """Load a Qwen multimodal model once and serve image+text generation."""

    # Chat-formatted prompt used when the caller supplies no (or empty) text.
    DEFAULT_PROMPT = (
        "<|im_start|>user\nDescribe this image.\n<|im_end|><|im_start|>assistant\n"
    )
    # Seconds to wait when downloading an image URL, so a slow or unresponsive
    # host cannot block the worker indefinitely.
    IMAGE_FETCH_TIMEOUT = 10

    def __init__(self, path=""):
        """Load model, image processor and tokenizer from *path*.

        The model runs in float16 on CUDA when available, float32 on CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = QwenForMultiModalConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)
        self.image_processor = QwenImageProcessor.from_pretrained(path)
        self.tokenizer = QwenTokenizer.from_pretrained(path)
        # Disables KV caching during generation, which makes decoding slower.
        # NOTE(review): presumably a workaround for a model quirk — confirm it
        # is still required.
        self.model.generation_config.use_cache = False

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Generate text for an image and an optional prompt.

        Args:
            data: Either raw image bytes from the request body, or a dict with
                - ``'image'``: a base64-encoded image (optionally as a
                  ``data:...;base64,`` URI) or an ``http(s)://`` URL;
                - ``'text'`` (optional): the prompt. When missing or empty,
                  ``DEFAULT_PROMPT`` is used.

        Returns:
            ``{"generated_text": str}`` with only the newly generated text, or
            ``{"error": str}`` when the input is malformed or the image cannot
            be loaded.
        """
        if isinstance(data, (bytes, bytearray)):
            image_source = data
            text_input = ""
        elif isinstance(data, dict):
            image_source = data.get("image", None)
            text_input = data.get("text", "")
            if image_source is None:
                return {"error": "No image provided."}
            if not isinstance(image_source, str):
                return {"error": "'image' must be a base64 string or an image URL."}
        else:
            return {
                "error": "Invalid input data. Expected binary image data "
                "or a dictionary with 'image' key."
            }

        try:
            image = self._load_image(image_source)
        # requests errors, binascii.Error (a ValueError) from bad base64, and
        # PIL's UnidentifiedImageError (an OSError) all mean "bad image input".
        except (requests.RequestException, ValueError, OSError) as exc:
            return {"error": f"Could not load image: {exc}"}

        prompt = text_input or self.DEFAULT_PROMPT
        return {"generated_text": self._generate(image, prompt)}

    def _load_image(self, source):
        """Return an RGB PIL image from raw bytes, an http(s) URL or base64 text.

        Raises:
            requests.RequestException: the URL could not be fetched or returned
                an HTTP error status.
            ValueError: the base64 payload is malformed.
            OSError: the bytes are not a readable image.
        """
        if isinstance(source, (bytes, bytearray)):
            raw = bytes(source)
        elif source.startswith(("http://", "https://")):
            response = requests.get(source, timeout=self.IMAGE_FETCH_TIMEOUT)
            # Fail fast instead of handing an HTML error page to PIL.
            response.raise_for_status()
            raw = response.content
        else:
            if source.startswith("data:"):
                # Accept data URIs such as "data:image/png;base64,<payload>".
                source = source.partition(",")[2]
            raw = base64.b64decode(source)
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")

    def _generate(self, image, prompt):
        """Run sampling-based generation and return only the new text."""
        # NOTE(review): the prompt carries no image placeholder tokens; confirm
        # the model attaches the pixel features without them.
        # NOTE(review): pixel values are only moved to the device, not cast to
        # the model dtype — confirm the model casts them under float16.
        image_inputs = self.image_processor(
            images=image, return_tensors="pt"
        ).to(self.device)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generated_ids = self.model.generate(
            **image_inputs,
            input_ids=input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
        # Assumes a decoder-only model (as Qwen chat models are): generate()
        # returns prompt + continuation, so drop the prompt tokens to avoid
        # echoing the input back to the caller.
        new_tokens = generated_ids[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)