from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import spaces
from helpers.utils import create_directory, delete_directory, generate_unique_name
import os
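# Reusable supervision annotators: bounding boxes, class labels and (when the
# model returns segmentation tokens) masks are drawn on top of each other.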
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VIDEO_TARGET_DIRECTORY = "tmp"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
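# Load PaliGemma 2 once at startup. The model is kept in bfloat16 so its weights
# match the bfloat16 inputs built below; the processor handles both text
# tokenization and image preprocessing for the 448x448 checkpoint.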
model_id = "google/paligemma2-3b-pt-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
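# Run a single detection prompt (e.g. "detect person ; dog") through PaliGemma 2.
# The model answers with location tokens such as "<loc0123>...<loc0456> person",
# which supervision parses into bounding boxes further down.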
def paligemma_detection(input_image, input_text):
    model_inputs = processor(
        text=input_text,
        images=input_image,
        return_tensors="pt"
    ).to(torch.bfloat16).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]
        result = processor.decode(generation, skip_special_tokens=True)
    return result
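# Parse the raw PaliGemma output into sv.Detections and draw boxes, labels and
# masks onto an OpenCV (BGR) image, returning a PIL image for Gradio to display.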
def annotate_image(result, resolution_wh, class_names, cv_image):
    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=resolution_wh,
        classes=class_names.split(',')
    )
    annotated_image = BOX_ANNOTATOR.annotate(
        scene=cv_image.copy(),
        detections=detections
    )
    annotated_image = LABEL_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = MASK_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image
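# Image tab callback: Gradio passes a PIL (RGB) image, so convert it to BGR for
# the annotators, run detection, and return the annotated image plus raw output.
# On a ZeroGPU Space the @spaces.GPU decorator attaches a GPU only for the
# duration of the call.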
@spaces.GPU
def process_image(input_image, input_text, class_names):
    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    result = paligemma_detection(input_image, input_text)
    annotated_image = annotate_image(
        result,
        (input_image.width, input_image.height),
        class_names,
        cv_image
    )
    return annotated_image, result
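# Video tab callback: run the same prompt on every frame and re-encode the
# annotated frames into a new mp4 under VIDEO_TARGET_DIRECTORY.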
@spaces.GPU
def process_video(input_video, input_text, class_names, progress=gr.Progress(track_tqdm=True)):
    if not input_video:
        gr.Info("Please upload a video.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    create_directory(frame_directory_path)

    video_info = sv.VideoInfo.from_video_path(input_video)
    frame_generator = sv.get_video_frames_generator(input_video)
    video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
    results = []
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        for frame in progress.tqdm(frame_generator, desc="Processing video"):
            pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            model_inputs = processor(
                text=input_text,
                images=pil_frame,
                return_tensors="pt"
            ).to(torch.bfloat16).to(model.device)
            input_len = model_inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
                generation = generation[0][input_len:]
                result = processor.decode(generation, skip_special_tokens=True)
            detections = sv.Detections.from_lmm(
                sv.LMM.PALIGEMMA,
                result,
                resolution_wh=(video_info.width, video_info.height),
                classes=class_names.split(',')
            )
            annotated_frame = BOX_ANNOTATOR.annotate(
                scene=frame.copy(),
                detections=detections
            )
            annotated_frame = LABEL_ANNOTATOR.annotate(
                scene=annotated_frame,
                detections=detections
            )
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=annotated_frame,
                detections=detections
            )
            results.append(result)
            sink.write_frame(annotated_frame)

    delete_directory(frame_directory_path)
    return video_path, results
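# Gradio UI: one tab for single-image detection and one for video detection,
# each wired to the corresponding callback above.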
with gr.Blocks() as app:
    gr.Markdown("""
## PaliGemma 2 Detection with Supervision - Demo
<br>
<div style="display: flex; gap: 10px;">
  <a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
    <img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github">
  </a>
  <a href="https://huggingface.co/blog/paligemma">
    <img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface">
  </a>
  <a href="https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab">
  </a>
  <a href="https://arxiv.org/abs/2412.03555">
    <img src="https://img.shields.io/badge/arXiv-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper">
  </a>
  <a href="https://supervision.roboflow.com/">
    <img src="https://img.shields.io/badge/Supervision-6706CE?style=flat&logo=Roboflow&logoColor=white" alt="Supervision">
  </a>
</div>
<br>

PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the
[Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile model for transfer
to a wide range of vision-language tasks such as image and short video captioning, visual question answering,
text reading, object detection and object segmentation.

This space shows how to use PaliGemma 2 for object detection with Supervision.
Provide an image or a video together with a detection prompt (for example `detect person ; dog`) and the class names to annotate.
""")
with gr.Tab("Image Detection"): | |
with gr.Row(): | |
with gr.Column(): | |
input_image = gr.Image(type="pil", label="Input Image") | |
input_text = gr.Textbox(lines=2, placeholder="Enter text here...", label="Enter prompt for example 'detect person;dog") | |
class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names") | |
with gr.Column(): | |
annotated_image = gr.Image(type="pil", label="Annotated Image") | |
detection_result = gr.Textbox(label="Detection Result") | |
gr.Button("Submit").click( | |
fn=process_image, | |
inputs=[input_image, input_text, class_names], | |
outputs=[annotated_image, detection_result] | |
) | |
with gr.Tab("Video Detection"): | |
with gr.Row(): | |
with gr.Column(): | |
input_video = gr.Video(label="Input Video") | |
input_text = gr.Textbox(lines=2, placeholder="Enter text here...", label="Enter prompt for example 'detect person;dog") | |
class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names") | |
with gr.Column(): | |
output_video = gr.Video(label="Annotated Video") | |
detection_result = gr.Textbox(label="Detection Result") | |
gr.Button("Process Video").click( | |
fn=process_video, | |
inputs=[input_video, input_text, class_names], | |
outputs=[output_video, detection_result] | |
) | |
if __name__ == "__main__":
    app.launch()