bioclip-demo / app.py
Samuel Stevens
wip: hierarchical prediction
705d528
raw
history blame
3.17 kB
import os
import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms
from templates import openai_imagenet_template
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
# Flagging callback handed to the Gradio interface in the __main__ block.
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

# Identifiers passed to open_clip's create_model / get_tokenizer at startup.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Per-channel normalization constants.
# NOTE(review): these look like the standard CLIP image statistics - confirm.
_normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

# Image pipeline: convert to a tensor, then normalize. There is no resize step
# here; the Gradio image input is declared with shape=(224, 224).
preprocess_img = transforms.Compose([transforms.ToTensor(), _normalize])
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one text embedding per class name, averaged over prompt templates.

    Every callable in ``templates`` is applied to a class name to produce a
    prompt. The prompts are tokenized with the module-level ``tokenizer``,
    encoded with the module-level ``model``, L2-normalized, averaged across
    templates and normalized again.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables mapping a class name to a prompt.

    Returns:
        A tensor in which the per-class embeddings are stacked along dim 1
        (one column per class).
    """

    def embed_class(name):
        # One prompt per template for this class.
        prompts = [render(name) for render in templates]
        tokens = tokenizer(prompts).to(device)
        per_prompt = F.normalize(model.encode_text(tokens), dim=-1)
        # Average the unit vectors, then rescale the mean back to unit norm.
        pooled = per_prompt.mean(dim=0)
        return pooled / pooled.norm()

    return torch.stack([embed_class(name) for name in classnames], dim=1)
@torch.no_grad()
def predict(img, classes: list[str]) -> dict[str, float]:
    """Score an image against a user-supplied list of class names.

    Blank or whitespace-only entries are dropped. The remaining names are
    embedded with ``get_txt_features`` using ``openai_imagenet_template``, the
    image is embedded with the module-level ``model``, and the scaled
    similarities are turned into probabilities with a softmax over classes.

    Args:
        img: Image accepted by ``preprocess_img``.
        classes: Candidate class names.

    Returns:
        Mapping from each (stripped) class name to its probability.

    Raises:
        ValueError: If no non-blank class name is provided.
    """
    classes = [name.strip() for name in classes if name.strip()]
    if not classes:
        # Fail early with a clear message instead of an opaque error from
        # stacking an empty list of text features.
        raise ValueError("predict() needs at least one non-empty class name")

    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    # unsqueeze(0) adds a leading batch axis of size 1.
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Drop only the batch axis. A bare squeeze() would also collapse the class
    # axis when a single class is given, making tolist() return a float and
    # breaking the zip below.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
@torch.no_grad()
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    Work in progress: only the image embedding is computed so far. The
    traversal over taxonomic ranks is not implemented yet, so the function
    raises instead of silently returning ``None``.

    Args:
        img: Image accepted by ``preprocess_img``.

    Raises:
        NotImplementedError: Always, until the hierarchical traversal exists.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # TODO: walk the taxonomy using img_features and return the predicted
    # labels. The leftover breakpoint() was removed because it blocks the
    # serving process whenever this path is hit.
    raise NotImplementedError("hierarchical prediction is not implemented yet")
def run(img, cls_str: str) -> dict[str, float]:
    """Gradio entry point: dispatch to flat or hierarchical prediction.

    ``cls_str`` holds one class name per line. When at least one non-blank
    name is present the image is scored against those names with ``predict``;
    otherwise the call falls through to ``hierarchical_predict``.

    Args:
        img: Image supplied by the Gradio image input.
        cls_str: Newline-separated class names; may be empty.

    Returns:
        Whatever the selected prediction function returns.
    """
    # Decide on the parsed names rather than the raw string, so a textbox
    # holding only whitespace or blank lines is treated as empty instead of
    # sending an empty class list to predict().
    classes = [line.strip() for line in (cls_str or "").split("\n") if line.strip()]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)
if __name__ == "__main__":
    print("Starting.")

    # `model` and `tokenizer` are bound at module scope here because the
    # prediction functions above read them as globals.
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    # UI components, built in the same order as they appear in the interface.
    image_input = gr.Image(shape=(224, 224))
    classes_input = gr.Textbox(
        placeholder="dog\ncat\n...",
        lines=3,
        label="Classes",
        show_label=True,
        info="If empty, will predict from the entire tree of life.",
    )
    predictions_output = gr.Label(
        num_top_classes=20,
        label="Predictions",
        show_label=True,
    )

    # Manual flagging sends flagged samples through hf_writer.
    demo = gr.Interface(
        fn=run,
        inputs=[image_input, classes_input],
        outputs=predictions_output,
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )
    demo.launch()