# Author: sophiamyang
# Last commit: 33701df ("add loading indicators")
import io
import random
from typing import List, Tuple

import panel as pn
import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
def set_random_url(_, *, timeout: float = 10.0) -> None:
    """Point the ``image_url`` widget at a random cat or dog picture.

    Registered as the ``randomize_url`` button's click callback; the click
    event argument is ignored.

    Parameters
    ----------
    _ : object
        The event passed by ``Button.on_click`` (unused).
    timeout : float
        Seconds to wait for the image-search API. Without a timeout a
        stalled request would block this callback indefinitely.

    Raises
    ------
    requests.RequestException
        On network failure, timeout, or an HTTP error status.
    """
    pet = random.choice(["cat", "dog"])
    api_url = f"https://api.the{pet}api.com/v1/images/search"
    with requests.get(api_url, timeout=timeout) as resp:
        resp.raise_for_status()
        # The API returns a JSON list of results; use the first image's URL.
        url = resp.json()[0]["url"]
    image_url.value = url
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Fetch a CLIP processor and a CLIP model by their pretrained identifiers.

    Decorated with ``pn.cache`` so that a given (processor_name, model_name)
    pair is loaded once and the same objects are reused on later calls.

    Returns
    -------
    tuple
        ``(processor, model)``, in that order.
    """
    clip_processor = CLIPProcessor.from_pretrained(processor_name)
    return clip_processor, CLIPModel.from_pretrained(model_name)
@pn.cache
def open_image_url(image_url: str, *, timeout: float = 10.0) -> Image.Image:
    """Download the image at *image_url* and return it as a decoded PIL image.

    Decorated with ``pn.cache`` so repeated requests for the same URL reuse
    the previously downloaded image.

    Parameters
    ----------
    image_url : str
        HTTP(S) URL of the image.
    timeout : float
        Seconds to wait for the server before giving up.

    Raises
    ------
    requests.RequestException
        On network failure, timeout, or an HTTP error status.
    PIL.UnidentifiedImageError
        If the downloaded bytes are not a recognised image format.
    """
    with requests.get(image_url, timeout=timeout) as resp:
        resp.raise_for_status()
        # ``resp.content`` honours Content-Encoding (gzip/deflate); ``resp.raw``
        # would hand PIL the still-encoded bytes.
        payload = resp.content
    image = Image.open(io.BytesIO(payload))
    # Image.open is lazy; decode now so any corruption surfaces here and the
    # cached object is fully loaded.
    image.load()
    return image
def get_similarity_scores(
    class_items: List[str],
    image: Image.Image,
    model_name: str = "openai/clip-vit-base-patch32",
) -> List[float]:
    """Score how well each candidate label describes *image* using CLIP.

    Parameters
    ----------
    class_items : list of str
        Candidate class labels; each is used as one text prompt.
    image : PIL.Image.Image
        The image to classify.
    model_name : str
        Hugging Face identifier used for both the processor and the model.

    Returns
    -------
    list of float
        Softmax probabilities, one per entry of *class_items* and in the same
        order. They sum to 1.
    """
    import torch  # local: only needed here; torch already backs return_tensors="pt"

    processor, model = load_processor_model(model_name, model_name)
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
    )
    # Inference only: skip building the autograd graph to save time and memory.
    with torch.no_grad():
        outputs = model(**inputs)
    # One row per input image (we pass a single image); softmax across labels.
    class_likelihoods = outputs.logits_per_image.softmax(dim=1)
    return class_likelihoods[0].tolist()
def _parse_class_names(class_names: str) -> List[str]:
    """Split a comma-separated string into stripped, non-empty labels."""
    return [name.strip() for name in class_names.split(",") if name.strip()]


def _build_result_row(class_item: str, class_likelihood: float) -> pn.Column:
    """Render one label with its percentage text and a matching progress bar."""
    row_label = pn.widgets.StaticText(
        name=class_item, value=f"{class_likelihood:.2%}", margin=(0, 10)
    )
    row_bar = pn.indicators.Progress(
        max=100,
        # Round rather than truncate so the bar agrees with the label above it.
        value=int(round(class_likelihood * 100)),
        sizing_mode="stretch_width",
        bar_color="secondary",
        margin=(0, 10),
    )
    return pn.Column(row_label, row_bar)


def process_inputs(class_names: str, image_url: str):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    Parameters
    ----------
    class_names : str
        Comma-separated candidate labels, e.g. ``"cat, dog, parrot"``.
        Blank entries (such as from a trailing comma) are ignored.
    image_url : str
        URL of the image to classify.

    Returns
    -------
    A ``pn.Column`` holding the image and one labelled progress bar per
    class, or a ``pn.pane.Alert`` explaining why classification could not
    run (missing input, or the image could not be downloaded/decoded).
    """
    class_items = _parse_class_names(class_names)
    if not image_url.strip() or not class_items:
        return pn.pane.Alert(
            "Enter an image URL and at least one class name.",
            alert_type="warning",
        )
    try:
        image = open_image_url(image_url)
    except (requests.RequestException, OSError) as err:
        # Show bad URLs / non-image responses in the UI instead of a traceback.
        # (PIL's UnidentifiedImageError is an OSError subclass.)
        return pn.pane.Alert(
            f"Could not load an image from {image_url}: {err}",
            alert_type="danger",
        )
    class_likelihoods = get_similarity_scores(class_items, image)

    # build the results column
    results_column = pn.Column("## 🎉 Here are the results!")
    results_column.append(
        pn.pane.Image(image, max_width=698, sizing_mode="scale_width")
    )
    results_column.extend(
        [
            _build_result_row(class_item, class_likelihood)
            for class_item, class_likelihood in zip(class_items, class_likelihoods)
        ]
    )
    return results_column
# create widgets
# Button that replaces the URL with a random cat/dog image (see set_random_url).
randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
# Also written to by set_random_url when the button is clicked.
image_url = pn.widgets.TextInput(
    name="Image URL to classify",
    value="https://cdn2.thecatapi.com/images/cct.jpg",
)
class_names = pn.widgets.TextInput(
    name="Comma separated class names",
    placeholder="Enter possible class names, e.g. cat, dog",
    value="cat, dog, parrot",
)
# Input section: heading, URL field beside its randomize button, class names.
input_widgets = pn.Column(
    "## 😊 Click randomize or paste a URL to start classifying!",
    pn.Row(image_url, randomize_url),
    class_names,
)
# add interactivity: the button swaps in a random URL, and the result panel
# re-runs process_inputs whenever either text input changes, showing a
# loading indicator while it computes.
randomize_url.on_click(set_random_url)
_bound_classifier = pn.bind(
    process_inputs, class_names=class_names, image_url=image_url
)
interactive_result = pn.panel(_bound_classifier, loading_indicator=True)
# create dashboard: inputs and live results grouped in one box, placed in a
# Bootstrap page template.
main = pn.WidgetBox(input_widgets, interactive_result)
template = pn.template.BootstrapTemplate(
    title="Panel Image Classification Demo",
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
)
# Mark the page as the app served by `panel serve`.
template.servable(title="Panel Image Classification Demo")