Spaces:
Sleeping
Sleeping
import gradio as gr | |
import PIL.Image | |
import transformers | |
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
import torch | |
import os | |
import string | |
import functools | |
import re | |
import numpy as np | |
import spaces | |
from PIL import Image, ImageDraw | |
import re | |
# Hugging Face Hub checkpoint served by this demo (a PaliGemma fine-tune; the
# example prompt below lists vehicle-damage classes it is asked to detect).
model_id = "mattraj/curacel-autodamage-1"

# Palette cycled through, by detection index, when drawing boxes and labels.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Use the GPU when torch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the weights in bfloat16, put the module in eval mode, then move it to `device`.
model = (
    PaliGemmaForConditionalGeneration
    .from_pretrained(model_id, torch_dtype=torch.bfloat16)
    .eval()
    .to(device)
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
###### Transformers Inference
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int = 2048
) -> tuple:
    """Run the fine-tuned PaliGemma model on one image/prompt and draw its detections.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input component.
        text: Prompt handed to the processor, e.g. ``"detect bonnet-dent ; paint-chip"``.
        max_new_tokens: Maximum number of tokens generated after the prompt.

    Returns:
        ``(result, annotated_image)``: the decoded model answer (prompt excluded) and
        an RGB copy of ``image`` with every parsed box and its label drawn on it.

    Raises:
        gr.Error: if no image was supplied; Gradio shows the message to the user.
    """
    # NOTE(review): `spaces` is imported at the top of the file but @spaces.GPU is not
    # applied here; if this Space runs on ZeroGPU hardware, this function probably
    # needs that decorator — confirm.
    if image is None:
        raise gr.Error("Please upload an image first.")

    inputs = processor(text=text, images=image, return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
    # generate() returns the prompt tokens (image + text) followed by the answer;
    # remember the prompt length so only the answer is decoded below.
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        # max_new_tokens (not max_length) so the long image+text prompt does not
        # eat into the generation budget.
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Assumes the <locNNNN> tokens are not registered as special tokens, so they
    # survive skip_special_tokens=True and can be parsed below — verify if boxes vanish.
    result = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True)

    bounding_boxes = extract_bounding_boxes(result, image)

    # convert() always returns a new image; RGB keeps the palette colours visible
    # even for grayscale/palette inputs.
    annotated_image = image.convert("RGB")
    draw = ImageDraw.Draw(annotated_image)
    for idx, (box, label) in enumerate(bounding_boxes):
        color = COLORS[idx % len(COLORS)]
        draw.rectangle(box, outline=color, width=3)
        # Label anchored at the box's top-left corner.
        draw.text((box[0], box[1]), label, fill=color)
    return result, annotated_image
def extract_bounding_boxes(result, image):
    """
    Extract bounding boxes and labels from PaliGemma detection output.

    PaliGemma writes each detection as four location tokens followed by a label,
    ``<locYMIN><locXMIN><locYMAX><locXMAX> label``, with detections separated by
    `` ; ``. Each token is a bin index; coordinates are scaled by dividing by 1024
    and multiplying by the matching image dimension.

    Args:
        result (str): The decoded model output containing ``<locNNNN>`` tokens.
        image: Object exposing ``size`` as ``(width, height)`` (e.g. a PIL image),
            used to scale the normalised coordinates to pixels.

    Returns:
        List[Tuple[Tuple[int, int, int, int], str]]: ``((x1, y1, x2, y2), label)``
        pairs in pixel coordinates with ``x1 <= x2`` and ``y1 <= y2``, in the order
        the detections appear in ``result``. Detections with an empty label are skipped.
    """
    # Four <locNNNN> tokens, then the label text up to the next ';', '<' or newline
    # (so multi-word labels such as "traffic light" are kept whole).
    loc_pattern = re.compile(r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<\n]+)")
    width, height = image.size

    def _scale(value, size):
        # Bin index -> pixel coordinate, clamped so it never exceeds the image extent.
        return min(int((int(value) / 1024) * size), size)

    bounding_boxes = []
    # Token order is y_min, x_min, y_max, x_max (rows before columns).
    for y_min, x_min, y_max, x_max, raw_label in loc_pattern.findall(result):
        label = raw_label.strip()
        if not label:
            continue
        xa, xb = _scale(x_min, width), _scale(x_max, width)
        ya, yb = _scale(y_min, height), _scale(y_max, height)
        # Normalise corner order: PIL's rectangle() rejects boxes with x1 < x0 or y1 < y0.
        bounding_boxes.append(((min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)), label))
    return bounding_boxes
######## Demo
INTRO_TEXT = """## Curacel Auto Damage demo\n\n
Finetuned from: google/paligemma-3b-pt-448
"""

# Damage classes named in the example "detect" prompt, in prompt order.
_EXAMPLE_DAMAGE_LABELS = (
    "Front-Windscreen-Damage", "Headlight-Damage", "Major-Rear-Bumper-Dent",
    "Rear-windscreen-Damage", "RunningBoard-Dent", "Sidemirror-Damage",
    "Signlight-Damage", "Taillight-Damage", "bonnet-dent", "doorouter-dent",
    "doorouter-scratch", "fender-dent", "front-bumper-dent", "front-bumper-scratch",
    "medium-Bodypanel-Dent", "paint-chip", "paint-trace", "pillar-dent",
    "quaterpanel-dent", "rear-bumper-dent", "rear-bumper-scratch", "roof-dent",
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Text Generation"):
        with gr.Column():
            # Components render top-to-bottom in creation order.
            image = gr.Image(type="pil")
            text_input = gr.Text(label="Input Text")
            text_output = gr.Text(label="Text Output")
            output_image = gr.Image(label="Annotated Image")
            chat_btn = gr.Button()

            # Order mirrors infer(image, text) -> (result, annotated_image).
            chat_outputs = [text_output, output_image]
            chat_inputs = [image, text_input]
            chat_btn.click(infer, inputs=chat_inputs, outputs=chat_outputs)

            # One example row: sample photo plus a prompt listing every damage class.
            example_prompt = "detect " + " ; ".join(_EXAMPLE_DAMAGE_LABELS)
            examples = [["./car-1.png", example_prompt]]
            gr.Markdown("")
            gr.Examples(examples=examples, inputs=chat_inputs)
#########
if __name__ == "__main__":
    # Enable the request queue (max_size=10), then start the server with debug=True.
    app = demo.queue(max_size=10)
    app.launch(debug=True)