from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import spaces
from helpers.file_utils import create_directory, delete_directory, generate_unique_name
from helpers.segment_utils import parse_segmentation, extract_objs
import os

BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VIDEO_TARGET_DIRECTORY = "tmp"
VAE_MODEL = "vae-oid.npz"
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
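
# Note: VAE_MODEL ("vae-oid.npz") is the mask-decoder checkpoint released with
# PaliGemma's segmentation pipeline; helpers/segment_utils.py is assumed to use
# it to decode <seg> tokens into binary masks.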

INTRO_TEXT = """
## PaliGemma 2 Detection/Segmentation with Supervision - Demo

Github | Huggingface | Colab | Paper | Supervision

PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile model for transfer to a wide range of vision-language tasks such as image and short-video captioning, visual question answering, text reading, object detection, and object segmentation.

This space shows how to use PaliGemma 2 for object detection and segmentation with supervision. You can input an image and a text prompt.
"""

create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

model_id = "google/paligemma2-3b-pt-448"
# Load the model in bfloat16 so its weights match the dtype the inputs are
# cast to below; mixing a float32 model with bfloat16 pixel values errors.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
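
# Prompt grammar handled by the parsers below (class names are illustrative):
#   "detect person;dog;car"  -> detection over classes ["person", "dog", "car"]
#   "segment person;dog"     -> segmentation over those classes
# Any other prefix falls through with task type None and the prompt unchanged.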

def parse_class_names(prompt):
    """Extract class names from a 'detect class1;class2' style prompt."""
    if not prompt.lower().startswith('detect '):
        return []
    classes_text = prompt[7:].strip()
    return [cls.strip() for cls in classes_text.split(';') if cls.strip()]


def parse_prompt_type(prompt):
    """Determine if the prompt is for detection or segmentation."""
    if prompt.lower().startswith('detect '):
        return 'detection', prompt[7:].strip()
    elif prompt.lower().startswith('segment '):
        return 'segmentation', prompt[8:].strip()
    return None, prompt


@spaces.GPU
def paligemma_detection(input_image, input_text, max_new_tokens):
    model_inputs = processor(
        text=input_text,
        images=input_image,
        return_tensors="pt"
    ).to(torch.bfloat16).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs,
                                    max_new_tokens=max_new_tokens,
                                    do_sample=False)
        generation = generation[0][input_len:]
        result = processor.decode(generation, skip_special_tokens=True)
    return result
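
# For reference: PaliGemma emits detections as four location tokens per box,
# in the order y_min, x_min, y_max, x_max, quantized to a 0-1023 grid, e.g.
# (values illustrative):
#   "<loc0256><loc0128><loc0768><loc0896> person ; <loc0010><loc0020><loc0500><loc0600> dog"
# sv.Detections.from_lmm(sv.LMM.PALIGEMMA, ...) rescales these tokens to pixel
# coordinates via resolution_wh and matches the trailing labels against `classes`.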

def annotate_image(result, resolution_wh, prompt, cv_image):
    class_names = parse_class_names(prompt)
    if not class_names:
        gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
        return cv_image

    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=resolution_wh,
        classes=class_names
    )

    annotated_image = BOX_ANNOTATOR.annotate(
        scene=cv_image.copy(), detections=detections)
    annotated_image = LABEL_ANNOTATOR.annotate(
        scene=annotated_image, detections=detections)
    annotated_image = MASK_ANNOTATOR.annotate(
        scene=annotated_image, detections=detections)

    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image


def process_image(input_image, input_text, max_new_tokens):
    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    prompt_type, cleaned_prompt = parse_prompt_type(input_text)

    if prompt_type == 'detection':
        result = paligemma_detection(input_image, input_text, max_new_tokens)
        class_names = [cls.strip() for cls in cleaned_prompt.split(';') if cls.strip()]
        detections = sv.Detections.from_lmm(
            sv.LMM.PALIGEMMA,
            result,
            resolution_wh=(input_image.width, input_image.height),
            classes=class_names
        )
        annotated_image = BOX_ANNOTATOR.annotate(scene=cv_image.copy(), detections=detections)
        annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
        annotated_image = MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
    elif prompt_type == 'segmentation':
        # Use parse_segmentation for segmentation tasks; it is expected to
        # decode the model output into (mask, label) pairs.
        result = paligemma_detection(input_image, input_text, max_new_tokens)
        input_image, annotations = parse_segmentation(input_image, result)

        annotated_image = cv_image.copy()
        for mask, label in annotations:
            if isinstance(mask, np.ndarray):
                # Pick a color per label; the hex palette is RGB, so reverse it
                # for the BGR image we are drawing on.
                color_idx = hash(label) % len(COLORS)
                rgb = tuple(int(COLORS[color_idx].lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))
                color = rgb[::-1]
                colored_mask = np.zeros_like(cv_image)
                colored_mask[mask > 0] = color

                # Blend the colored mask with the image
                alpha = 0.5
                annotated_image = cv2.addWeighted(annotated_image, 1, colored_mask, alpha, 0)

                # Put the label at the top-left corner of the mask
                y_coords, x_coords = np.where(mask > 0)
                if len(y_coords) > 0 and len(x_coords) > 0:
                    label_y = y_coords.min()
                    label_x = x_coords.min()
                    cv2.putText(annotated_image, label, (label_x, label_y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    else:
        gr.Warning("Invalid prompt format. Please use 'detect' or 'segment' followed by class names")
        return input_image, "Invalid prompt format"

    # Convert back to RGB for display
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image, result
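
# Usage sketch outside the Gradio UI (file names are illustrative):
#   image = Image.open("example.jpg")
#   annotated, raw = process_image(image, "detect person;car", max_new_tokens=100)
#   annotated.save("annotated.jpg")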

@spaces.GPU
def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
    if not input_video:
        gr.Info("Please upload a video.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    class_names = parse_class_names(input_text)
    if not class_names:
        gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
        return None, None

    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    create_directory(frame_directory_path)

    video_info = sv.VideoInfo.from_video_path(input_video)
    frame_generator = sv.get_video_frames_generator(input_video)
    video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")

    results = []
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        for frame in progress.tqdm(frame_generator, desc="Processing video"):
            pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            model_inputs = processor(
                text=input_text,
                images=pil_frame,
                return_tensors="pt"
            ).to(torch.bfloat16).to(model.device)
            input_len = model_inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                generation = model.generate(**model_inputs,
                                            max_new_tokens=max_new_tokens,
                                            do_sample=False)
                generation = generation[0][input_len:]
                result = processor.decode(generation, skip_special_tokens=True)

            detections = sv.Detections.from_lmm(
                sv.LMM.PALIGEMMA,
                result,
                resolution_wh=(video_info.width, video_info.height),
                classes=class_names
            )

            annotated_frame = BOX_ANNOTATOR.annotate(
                scene=frame.copy(), detections=detections)
            annotated_frame = LABEL_ANNOTATOR.annotate(
                scene=annotated_frame, detections=detections)
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=annotated_frame, detections=detections)

            results.append(result)
            sink.write_frame(annotated_frame)

    delete_directory(frame_directory_path)
    return video_path, results


with gr.Blocks() as app:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Image Detection/Segmentation"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter prompt in format like this: detect person;dog;building or segment person;dog;building",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10,
                                           label="Max New Tokens",
                                           info="Set to larger for longer generation.")
            with gr.Column():
                annotated_image = gr.Image(type="pil", label="Annotated Image")
                detection_result = gr.Textbox(label="Detection Result")

        gr.Button("Submit").click(
            fn=process_image,
            inputs=[input_image, input_text, max_new_tokens],
            outputs=[annotated_image, detection_result]
        )

    with gr.Tab("Video Detection"):
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter prompt in format like this: detect person;dog;building (video supports detection only)",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1,
                                           label="Max New Tokens",
                                           info="Set to larger for longer generation.")
            with gr.Column():
                output_video = gr.Video(label="Annotated Video")
                detection_result = gr.Textbox(label="Detection Result")

        gr.Button("Process Video").click(
            fn=process_video,
            inputs=[input_video, input_text, max_new_tokens],
            outputs=[output_video, detection_result]
        )

if __name__ == "__main__":
    app.launch(ssr_mode=False)
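
# Usage sketch for the video pipeline outside the UI (file name illustrative):
#   video_path, raw_results = process_video("input.mp4", "detect person;car", 100)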