# Qwen2-VL-7B-Sydney / handler.py
from typing import Any, Dict

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class EndpointHandler:
    """Custom handler that loads Qwen2-VL and runs multimodal chat generation."""

    def __init__(self, path: str = "") -> None:
        # Load the Qwen2-VL model from the repository path onto the available devices.
        # torch_dtype="auto" and device_map="auto" let transformers pick the dtype
        # and device placement automatically.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype="auto",
            device_map="auto",
        )
        # Load the default processor, which handles chat templating, image resizing,
        # and (optionally) video preprocessing for Qwen2-VL.
        self.processor = AutoProcessor.from_pretrained(path)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run generation for one request payload.

        Expects `data` to contain a "messages" key with the conversation in the
        Qwen2-VL chat format and, optionally, "max_new_tokens". Returns a dict
        with the decoded model output under "output".
        """
        # Extract the conversation messages from the input data.
        messages = data.get("messages")
        if messages is None:
            raise ValueError("Input data must contain a 'messages' key with conversation data.")

        # Build the text prompt with the processor's chat template.
        # This adds the necessary role markers and the generation prompt.
        text_prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process any visual inputs (images and/or videos) referenced in the messages.
        # The helper from qwen_vl_utils handles various formats (URLs, base64, local files).
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare the model inputs (tokenized text plus pixel values).
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move all tensors to the device where the model is loaded.
        inputs = inputs.to(self.model.device)

        # Generate output tokens; an optional "max_new_tokens" value can be passed
        # in the request data.
        max_new_tokens = data.get("max_new_tokens", 128)
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Remove the input prompt tokens from each generated sequence.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the token ids to obtain the final text output.
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return {"output": output_text}