import torch
import gradio as gr
import re
from PIL import ImageDraw, Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

mix_model_id = "google/paligemma-3b-mix-224"

mix_model = PaliGemmaForConditionalGeneration.from_pretrained(mix_model_id)
mix_processor = AutoProcessor.from_pretrained(mix_model_id)


# Helper function to parse multiple <loc> tags and return a list of coordinate sets and labels
def parse_multiple_locations(decoded_output):
    # Regex pattern to match four <locXXXX> tags and the label at the end (e.g., 'cat')
    loc_pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s+(\w+)"
    matches = re.findall(loc_pattern, decoded_output)

    coords_and_labels = []
    for match in matches:
        # Extract the coordinates and label; loc tokens are binned into 1024 positions,
        # so dividing by 1024 gives coordinates normalized to [0, 1]
        y1 = int(match[0]) / 1024
        x1 = int(match[1]) / 1024
        y2 = int(match[2]) / 1024
        x2 = int(match[3]) / 1024
        label = match[4]

        coords_and_labels.append({
            'label': label,
            'bbox': [y1, x1, y2, x2]
        })

    return coords_and_labels


# Helper function to draw bounding boxes and labels for all objects on the image
def draw_multiple_bounding_boxes(image, coords_and_labels):
    draw = ImageDraw.Draw(image)
    width, height = image.size

    for obj in coords_and_labels:
        # Scale the normalized bounding box coordinates to pixel values
        y1, x1, y2, x2 = (obj['bbox'][0] * height, obj['bbox'][1] * width,
                          obj['bbox'][2] * height, obj['bbox'][3] * width)

        # Draw bounding box and label
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        draw.text((x1, y1), obj['label'], fill="red")

    return image


# Define inference function
def process_image(image, prompt):
    # Process the image and prompt using the processor
    inputs = mix_processor(images=image.convert("RGB"), text=prompt, return_tensors="pt")

    try:
        # Generate output from the model
        output = mix_model.generate(**inputs, max_new_tokens=100)

        # Decode the output from the model
        decoded_output = mix_processor.decode(output[0], skip_special_tokens=True)

        # Extract bounding box coordinates and labels
        coords_and_labels = parse_multiple_locations(decoded_output)

        if coords_and_labels:
            # Draw bounding boxes and labels on the image
            image_with_boxes = draw_multiple_bounding_boxes(image, coords_and_labels)

            # Prepare the coordinates and labels for the UI
            labels_and_coords = "\n".join(
                [f"Label: {obj['label']}, Coordinates: {obj['bbox']}" for obj in coords_and_labels]
            )

            # Return the modified image and the list of coordinates + labels
            return image_with_boxes, labels_and_coords
        else:
            # Still return two values so both Gradio outputs are populated
            return image, "No bounding boxes detected."
    except IndexError as e:
        print(f"IndexError: {e}")
        return image, "An error occurred during processing."


# Define the Gradio interface
inputs = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question")
]
outputs = [
    gr.Image(label="Output Image with Bounding Boxes"),
    gr.Textbox(label="Bounding Box Coordinates and Labels")
]

# Create the Gradio app
demo = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs,
                    title="Object Detection with Mix PaliGemma Model",
                    description="Upload an image and get object detections with bounding boxes and labels.")

# Launch the app
demo.launch(debug=True)
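A quick usage note: the mix checkpoints answer detection-style prompts, so a prompt such as `detect cat` (or something like `detect cat ; dog` for several classes) should make the model emit the `<loc....>` tokens that `parse_multiple_locations` looks for, while a free-form question will usually come back as plain text with no boxes to draw. The regex above also only captures single-word labels; multi-word class names would need a broader capture group.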