|
import random |
|
import panel as pn |
|
import requests |
|
from PIL import Image |
|
|
|
from transformers import CLIPProcessor, CLIPModel |
|
from typing import List, Tuple |
|
|
|
|
|
def set_random_url(_):
    """Fill the image URL widget with a random cat or dog picture URL.

    Used as the "Randomize URL" button's click callback; the event argument
    is ignored. Picks TheCatAPI or TheDogAPI at random, requests one random
    image from its ``/v1/images/search`` endpoint, and writes that image's URL
    into the module-level ``image_url`` widget.

    Raises:
        requests.HTTPError: if the API responds with an error status.
        requests.Timeout: if the API does not answer within the timeout.
    """
    pet = random.choice(["cat", "dog"])
    api_url = f"https://api.the{pet}api.com/v1/images/search"
    # Without a timeout a stalled API would block this callback indefinitely.
    with requests.get(api_url, timeout=10) as resp:
        resp.raise_for_status()
        results = resp.json()
    if not results:
        # The API returned no images; keep the current URL instead of
        # failing with an IndexError.
        return
    image_url.value = results[0]["url"]
|
|
|
|
|
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Load a CLIP processor and a CLIP model by their pretrained names.

    Memoized with ``pn.cache`` so repeated calls with the same names reuse
    the already-loaded objects instead of reloading the weights.

    Args:
        processor_name: Identifier passed to ``CLIPProcessor.from_pretrained``.
        model_name: Identifier passed to ``CLIPModel.from_pretrained``.

    Returns:
        A ``(processor, model)`` tuple; the processor is loaded first.
    """
    return (
        CLIPProcessor.from_pretrained(processor_name),
        CLIPModel.from_pretrained(model_name),
    )
|
|
|
|
|
@pn.cache
def open_image_url(image_url: str) -> Image.Image:
    """Download the image at *image_url* and return it as a decoded PIL image.

    Memoized with ``pn.cache`` so re-classifying the same URL (e.g. after
    editing only the class names) does not download the image again.

    Args:
        image_url: HTTP(S) URL of the image.

    Returns:
        The opened image with its pixel data already loaded.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout: if the server does not answer within the timeout.
        PIL.UnidentifiedImageError: if the body is not a recognised image.
    """
    with requests.get(image_url, stream=True, timeout=30) as resp:
        resp.raise_for_status()
        # ``resp.raw`` is the undecoded byte stream; ask urllib3 to undo any
        # Content-Encoding (e.g. gzip) so PIL receives the actual image bytes.
        resp.raw.decode_content = True
        image = Image.open(resp.raw)
        # Image.open only parses the header. Decode the pixels now so that a
        # truncated or corrupt download fails here rather than later, when
        # the cached image is rendered.
        image.load()
    return image
|
|
|
|
|
def get_similarity_scores(class_items: List[str], image: Image.Image) -> List[float]:
    """Score how well each class name describes *image* using CLIP.

    Args:
        class_items: Candidate class names; each is encoded as a text prompt.
        image: The image to classify.

    Returns:
        Softmax probabilities over ``class_items`` (same order, summing to 1)
        as a list of floats. Returns an empty list when ``class_items`` is
        empty.
    """
    if not class_items:
        return []

    # Requesting "pt" tensors below already requires PyTorch to be installed.
    import torch

    processor, model = load_processor_model(
        "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
    )
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",
        # Class names tokenize to different lengths ("cat" vs. "golden
        # retriever"). Without padding they cannot be stacked into one tensor.
        padding=True,
    )
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image: one row per image, one column per text prompt.
    class_likelihoods = outputs.logits_per_image.softmax(dim=1)
    return class_likelihoods[0].numpy().tolist()
|
|
|
|
|
def process_inputs(class_names: str, image_url: str):
    """Classify the image at *image_url* against comma-separated class names.

    High level function that takes in the raw user inputs (the text-widget
    values) and returns the classification results as Panel objects.

    Args:
        class_names: Comma separated class names, e.g. ``"cat, dog, parrot"``.
        image_url: URL of the image to classify.

    Returns:
        A ``pn.Column`` holding the image followed by one labelled progress
        bar per class name. If either input is empty, the column holds a short
        hint instead.
    """
    # Strip whitespace so the model scores "dog", not " dog", and drop empty
    # entries produced by stray commas such as "cat, dog,".
    class_items = [item.strip() for item in class_names.split(",") if item.strip()]
    if not class_items:
        return pn.Column("**Enter at least one class name to classify the image.**")
    if not image_url.strip():
        return pn.Column("**Enter an image URL to classify.**")

    image = open_image_url(image_url.strip())
    class_likelihoods = get_similarity_scores(class_items, image)

    results_column = pn.Column("## π Here are the results!")
    results_column.append(
        pn.pane.Image(image, max_width=698, sizing_mode="scale_width")
    )

    for class_item, class_likelihood in zip(class_items, class_likelihoods):
        row_label = pn.widgets.StaticText(
            name=class_item, value=f"{class_likelihood:.2%}", margin=(0, 10)
        )
        row_bar = pn.indicators.Progress(
            max=100,
            # Round rather than truncate so the bar agrees with the label.
            value=int(round(float(class_likelihood) * 100)),
            sizing_mode="stretch_width",
            bar_color="secondary",
            margin=(0, 10),
        )
        results_column.append(pn.Column(row_label, row_bar))
    return results_column
|
|
|
|
|
# --- Input widgets -----------------------------------------------------------
image_url = pn.widgets.TextInput(
    value="https://cdn2.thecatapi.com/images/cct.jpg",
    name="Image URL to classify",
)
class_names = pn.widgets.TextInput(
    value="cat, dog, parrot",
    placeholder="Enter possible class names, e.g. cat, dog",
    name="Comma separated class names",
)
randomize_url = pn.widgets.Button(align="end", name="Randomize URL")
randomize_url.on_click(set_random_url)

# Header, then the URL box with the randomize button beside it, then the
# class-name box.
input_widgets = pn.Column(
    "## π Click randomize or paste a URL to start classifying!",
    pn.Row(image_url, randomize_url),
    class_names,
)

# --- Reactive output ---------------------------------------------------------
# process_inputs is bound to both text widgets, so it re-runs when either
# value changes; a loading indicator is shown while it works.
interactive_result = pn.panel(
    pn.bind(process_inputs, class_names=class_names, image_url=image_url),
    loading_indicator=True,
)

# --- Page --------------------------------------------------------------------
main = pn.WidgetBox(input_widgets, interactive_result)

pn.template.BootstrapTemplate(
    header_background="#F08080",
    main=main,
    main_max_width="min(50%, 698px)",
    title="Panel Image Classification Demo",
).servable(title="Panel Image Classification Demo")