"""Gradio demo: PaliGemma 2 object detection on images and videos.

The model output is parsed and drawn with Supervision annotators. Prompts
must follow the ``detect class1;class2;class3`` format.
"""

import os

import cv2
import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

from helpers.utils import create_directory, delete_directory, generate_unique_name

BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VIDEO_TARGET_DIRECTORY = "tmp"
INTRO_TEXT = """ ## PaliGemma 2 Detection with Supervision - Demo
PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question answering, text reading, object detection and object segmentation. This space show how to use PaliGemma 2 for object detection with supervision. You can input an image and a text prompt """

# Prefix every detection prompt must start with (matched case-insensitively).
_DETECT_PREFIX = "detect "
# Shown to the user whenever no class names can be parsed from the prompt.
_INVALID_PROMPT_MESSAGE = (
    "Invalid prompt format. Please use 'detect class1;class2;class3' format"
)

create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

model_id = "google/paligemma2-3b-pt-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)


def parse_class_names(prompt):
    """Extract the class names from a ``detect a;b;c`` prompt.

    Args:
        prompt: The user prompt. ``None`` or an empty string is treated as
            an invalid prompt.

    Returns:
        The non-empty, whitespace-stripped class names in prompt order, or
        an empty list when the prompt does not start with ``detect ``.
    """
    if not prompt or not prompt.lower().startswith(_DETECT_PREFIX):
        return []
    classes_text = prompt[len(_DETECT_PREFIX):].strip()
    return [name.strip() for name in classes_text.split(";") if name.strip()]


def _generate_text(image, prompt, max_new_tokens):
    """Run greedy generation for one image/prompt pair and decode the new tokens."""
    model_inputs = (
        processor(text=prompt, images=image, return_tensors="pt")
        .to(torch.bfloat16)
        .to(model.device)
    )
    # generate() returns prompt + continuation; remember where the prompt ends.
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(
            **model_inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    return processor.decode(generation[0][input_len:], skip_special_tokens=True)


def _parse_detections(result, resolution_wh, class_names):
    """Convert raw PaliGemma output text into Supervision detections."""
    return sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=resolution_wh,
        classes=class_names,
    )


def _annotate(scene, detections):
    """Draw boxes, labels and masks on a copy of ``scene`` and return the copy."""
    annotated = BOX_ANNOTATOR.annotate(scene=scene.copy(), detections=detections)
    annotated = LABEL_ANNOTATOR.annotate(scene=annotated, detections=detections)
    return MASK_ANNOTATOR.annotate(scene=annotated, detections=detections)


@spaces.GPU
def paligemma_detection(input_image, input_text, max_new_tokens):
    """Return the text PaliGemma generates for ``input_image`` and ``input_text``."""
    return _generate_text(input_image, input_text, max_new_tokens)


def annotate_image(result, resolution_wh, prompt, cv_image):
    """Draw the detections described by ``result`` onto ``cv_image``.

    Args:
        result: Decoded model output text.
        resolution_wh: ``(width, height)`` of the image, passed through to
            ``sv.Detections.from_lmm``.
        prompt: The user prompt; class names are parsed from it.
        cv_image: The image as a BGR array (as produced by
            ``cv2.COLOR_RGB2BGR`` in ``process_image``).

    Returns:
        An RGB ``PIL.Image``. When the prompt is invalid a warning is shown
        and the image is returned without annotations.
    """
    class_names = parse_class_names(prompt)
    if class_names:
        detections = _parse_detections(result, resolution_wh, class_names)
        output = _annotate(cv_image, detections)
    else:
        gr.Warning(_INVALID_PROMPT_MESSAGE)
        output = cv_image
    # Both branches convert back to RGB so the fallback image is not shown
    # with swapped red/blue channels.
    return Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))


def process_image(input_image, input_text, max_new_tokens):
    """Gradio handler for the image tab.

    Returns:
        ``(annotated_image, result)``, or ``(None, None)`` when the image or
        the prompt is missing.
    """
    if input_image is None:
        gr.Info("Please upload an image.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    result = paligemma_detection(input_image, input_text, max_new_tokens)
    annotated_image = annotate_image(
        result, (input_image.width, input_image.height), input_text, cv_image
    )
    return annotated_image, result


@spaces.GPU
def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler for the video tab: detect and annotate every frame.

    Returns:
        ``(video_path, results)`` where ``results`` holds the decoded model
        output of each frame, or ``(None, None)`` when an input is missing
        or the prompt is invalid.
    """
    # The click handler is bound to two outputs, so every exit returns a pair.
    if not input_video:
        gr.Info("Please upload a video.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None
    class_names = parse_class_names(input_text)
    if not class_names:
        gr.Warning(_INVALID_PROMPT_MESSAGE)
        return None, None

    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    create_directory(frame_directory_path)
    try:
        video_info = sv.VideoInfo.from_video_path(input_video)
        frame_generator = sv.get_video_frames_generator(input_video)
        video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
        resolution_wh = (video_info.width, video_info.height)
        results = []
        with sv.VideoSink(video_path, video_info=video_info) as sink:
            for frame in progress.tqdm(frame_generator, desc="Processing video"):
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                result = _generate_text(pil_frame, input_text, max_new_tokens)
                detections = _parse_detections(result, resolution_wh, class_names)
                results.append(result)
                sink.write_frame(_annotate(frame, detections))
    finally:
        # Remove the per-run directory even when processing fails midway.
        delete_directory(frame_directory_path)
    return video_path, results


with gr.Blocks() as app:
    gr.Markdown(INTRO_TEXT)

    with gr.Tab("Image Detection"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter prompt in format like this: detect person;dog;building",
                    label="Enter detection prompt",
                )
                max_new_tokens = gr.Slider(
                    minimum=20,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Set to larger for longer generation.",
                )
            with gr.Column():
                annotated_image = gr.Image(type="pil", label="Annotated Image")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Submit").click(
            fn=process_image,
            inputs=[input_image, input_text, max_new_tokens],
            outputs=[annotated_image, detection_result],
        )

    with gr.Tab("Video Detection"):
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                video_input_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter prompt in format like this: detect person;dog;building",
                    label="Enter detection prompt",
                )
                video_max_new_tokens = gr.Slider(
                    minimum=20,
                    maximum=200,
                    value=100,
                    step=1,
                    label="Max New Tokens",
                    info="Set to larger for longer generation.",
                )
            with gr.Column():
                output_video = gr.Video(label="Annotated Video")
                video_detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Process Video").click(
            fn=process_video,
            inputs=[input_video, video_input_text, video_max_new_tokens],
            outputs=[output_video, video_detection_result],
        )


if __name__ == "__main__":
    app.launch(ssr_mode=False)