|
""" |
|
Do zero-shot image classification. |
|
|
|
Writes the output to a plaintext and JSON format in the logs directory. |
|
""" |
|
import argparse |
|
import ast |
|
import contextlib |
|
import json |
|
import logging |
|
import os |
|
import random |
|
import sys |
|
|
|
import numpy as np |
|
import open_clip |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import datasets |
|
from tqdm import tqdm |
|
|
|
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" |
|
logging.basicConfig(level=logging.INFO, format=log_format) |
|
logger = logging.getLogger("main") |
|
|
|
openai_templates = [ |
|
lambda c: f"a bad photo of a {c}.", |
|
lambda c: f"a photo of many {c}.", |
|
lambda c: f"a sculpture of a {c}.", |
|
lambda c: f"a photo of the hard to see {c}.", |
|
lambda c: f"a low resolution photo of the {c}.", |
|
lambda c: f"a rendering of a {c}.", |
|
lambda c: f"graffiti of a {c}.", |
|
lambda c: f"a bad photo of the {c}.", |
|
lambda c: f"a cropped photo of the {c}.", |
|
lambda c: f"a tattoo of a {c}.", |
|
lambda c: f"the embroidered {c}.", |
|
lambda c: f"a photo of a hard to see {c}.", |
|
lambda c: f"a bright photo of a {c}.", |
|
lambda c: f"a photo of a clean {c}.", |
|
lambda c: f"a photo of a dirty {c}.", |
|
lambda c: f"a dark photo of the {c}.", |
|
lambda c: f"a drawing of a {c}.", |
|
lambda c: f"a photo of my {c}.", |
|
lambda c: f"the plastic {c}.", |
|
lambda c: f"a photo of the cool {c}.", |
|
lambda c: f"a close-up photo of a {c}.", |
|
lambda c: f"a black and white photo of the {c}.", |
|
lambda c: f"a painting of the {c}.", |
|
lambda c: f"a painting of a {c}.", |
|
lambda c: f"a pixelated photo of the {c}.", |
|
lambda c: f"a sculpture of the {c}.", |
|
lambda c: f"a bright photo of the {c}.", |
|
lambda c: f"a cropped photo of a {c}.", |
|
lambda c: f"a plastic {c}.", |
|
lambda c: f"a photo of the dirty {c}.", |
|
lambda c: f"a jpeg corrupted photo of a {c}.", |
|
lambda c: f"a blurry photo of the {c}.", |
|
lambda c: f"a photo of the {c}.", |
|
lambda c: f"a good photo of the {c}.", |
|
lambda c: f"a rendering of the {c}.", |
|
lambda c: f"a {c} in a video game.", |
|
lambda c: f"a photo of one {c}.", |
|
lambda c: f"a doodle of a {c}.", |
|
lambda c: f"a close-up photo of the {c}.", |
|
lambda c: f"a photo of a {c}.", |
|
lambda c: f"the origami {c}.", |
|
lambda c: f"the {c} in a video game.", |
|
lambda c: f"a sketch of a {c}.", |
|
lambda c: f"a doodle of the {c}.", |
|
lambda c: f"a origami {c}.", |
|
lambda c: f"a low resolution photo of a {c}.", |
|
lambda c: f"the toy {c}.", |
|
lambda c: f"a rendition of the {c}.", |
|
lambda c: f"a photo of the clean {c}.", |
|
lambda c: f"a photo of a large {c}.", |
|
lambda c: f"a rendition of a {c}.", |
|
lambda c: f"a photo of a nice {c}.", |
|
lambda c: f"a photo of a weird {c}.", |
|
lambda c: f"a blurry photo of a {c}.", |
|
lambda c: f"a cartoon {c}.", |
|
lambda c: f"art of a {c}.", |
|
lambda c: f"a sketch of the {c}.", |
|
lambda c: f"a embroidered {c}.", |
|
lambda c: f"a pixelated photo of a {c}.", |
|
lambda c: f"itap of the {c}.", |
|
lambda c: f"a jpeg corrupted photo of the {c}.", |
|
lambda c: f"a good photo of a {c}.", |
|
lambda c: f"a plushie {c}.", |
|
lambda c: f"a photo of the nice {c}.", |
|
lambda c: f"a photo of the small {c}.", |
|
lambda c: f"a photo of the weird {c}.", |
|
lambda c: f"the cartoon {c}.", |
|
lambda c: f"art of the {c}.", |
|
lambda c: f"a drawing of the {c}.", |
|
lambda c: f"a photo of the large {c}.", |
|
lambda c: f"a black and white photo of a {c}.", |
|
lambda c: f"the plushie {c}.", |
|
lambda c: f"a dark photo of a {c}.", |
|
lambda c: f"itap of a {c}.", |
|
lambda c: f"graffiti of the {c}.", |
|
lambda c: f"a toy {c}.", |
|
lambda c: f"itap of my {c}.", |
|
lambda c: f"a photo of a cool {c}.", |
|
lambda c: f"a photo of a small {c}.", |
|
lambda c: f"a tattoo of the {c}.", |
|
] |
|
|
|
|
|
def parse_args(args): |
|
class ParseKwargs(argparse.Action): |
|
def __call__(self, parser, namespace, values, option_string=None): |
|
kw = {} |
|
for value in values: |
|
key, value = value.split("=") |
|
try: |
|
kw[key] = ast.literal_eval(value) |
|
except (ValueError, SyntaxError): |
|
|
|
kw[key] = str(value) |
|
setattr(namespace, self.dest, kw) |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--datasets", |
|
type=str, |
|
default=None, |
|
nargs="+", |
|
help="Path to dirs(s) with validation data. In the format NAME=PATH.", |
|
action=ParseKwargs, |
|
) |
|
parser.add_argument( |
|
"--logs", type=str, default="./logs", help="Where to write logs" |
|
) |
|
parser.add_argument( |
|
"--exp", type=str, default="bioclip-zero-shot", help="Experiment name." |
|
) |
|
parser.add_argument( |
|
"--workers", type=int, default=8, help="Number of dataloader workers per GPU." |
|
) |
|
parser.add_argument( |
|
"--batch-size", type=int, default=64, help="Batch size per GPU." |
|
) |
|
parser.add_argument( |
|
"--precision", |
|
choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp32"], |
|
default="amp", |
|
help="Floating point precision.", |
|
) |
|
parser.add_argument("--seed", type=int, default=0, help="Default random seed.") |
|
args = parser.parse_args(args) |
|
os.makedirs(os.path.join(args.logs, args.exp), exist_ok=True) |
|
|
|
return args |
|
|
|
|
|
def make_txt_features(model, classnames, templates, args): |
|
tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip") |
|
with torch.no_grad(): |
|
txt_features = [] |
|
for classname in tqdm(classnames): |
|
classname = " ".join(word for word in classname.split("_") if word) |
|
texts = [template(classname) for template in templates] |
|
texts = tokenizer(texts).to(args.device) |
|
class_embeddings = model.encode_text(texts) |
|
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) |
|
class_embedding /= class_embedding.norm() |
|
txt_features.append(class_embedding) |
|
txt_features = torch.stack(txt_features, dim=1).to(args.device) |
|
return txt_features |
|
|
|
|
|
def accuracy(output, target, topk=(1,)): |
|
pred = output.topk(max(topk), 1, True, True)[1].t() |
|
correct = pred.eq(target.view(1, -1).expand_as(pred)) |
|
return [correct[:k].reshape(-1).float().sum(0, keepdim=True).item() for k in topk] |
|
|
|
|
|
def get_autocast(precision): |
|
if precision == "amp": |
|
return torch.cuda.amp.autocast |
|
elif precision == "amp_bfloat16" or precision == "amp_bf16": |
|
|
|
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) |
|
else: |
|
return contextlib.suppress |
|
|
|
|
|
def run(model, txt_features, dataloader, args): |
|
autocast = get_autocast(args.precision) |
|
cast_dtype = open_clip.get_cast_dtype(args.precision) |
|
|
|
top1, top5, n = 0.0, 0.0, 0.0 |
|
|
|
with torch.no_grad(): |
|
for images, targets in tqdm(dataloader, unit_scale=args.batch_size): |
|
images = images.to(args.device) |
|
if cast_dtype is not None: |
|
images = images.to(dtype=cast_dtype) |
|
targets = targets.to(args.device) |
|
|
|
with autocast(): |
|
image_features = model.encode_image(images) |
|
image_features = F.normalize(image_features, dim=-1) |
|
logits = model.logit_scale.exp() * image_features @ txt_features |
|
|
|
|
|
acc1, acc5 = accuracy(logits, targets, topk=(1, 5)) |
|
top1 += acc1 |
|
top5 += acc5 |
|
n += images.size(0) |
|
|
|
top1 = top1 / n |
|
top5 = top5 / n |
|
return top1, top5 |
|
|
|
|
|
def evaluate(model, data, args): |
|
results = {} |
|
|
|
logger.info("Starting zero-shot classification.") |
|
|
|
for split in data: |
|
logger.info("Building zero-shot %s classifier.", split) |
|
|
|
classnames = data[split].dataset.classes |
|
classnames = [name.replace("_", " ") for name in classnames] |
|
|
|
txt_features = make_txt_features(model, classnames, openai_templates, args) |
|
|
|
logger.info("Got text features.") |
|
top1, top5 = run(model, txt_features, data[split], args) |
|
|
|
logger.info("%s-top1: %.3f", split, top1 * 100) |
|
logger.info("%s-top5: %.3f", split, top5 * 100) |
|
|
|
results[f"{split}-top1"] = top1 * 100 |
|
results[f"{split}-top5"] = top5 * 100 |
|
|
|
logger.info("Finished zero-shot %s.", split) |
|
|
|
logger.info("Finished zero-shot classification.") |
|
|
|
return results |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args(sys.argv[1:]) |
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.benchmark = True |
|
torch.backends.cudnn.deterministic = False |
|
|
|
|
|
if torch.cuda.is_available(): |
|
device = "cuda:0" |
|
torch.cuda.set_device(device) |
|
else: |
|
device = "cpu" |
|
args.device = device |
|
|
|
|
|
torch.manual_seed(args.seed) |
|
np.random.seed(args.seed) |
|
random.seed(args.seed) |
|
|
|
|
|
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms( |
|
"hf-hub:imageomics/bioclip" |
|
) |
|
|
|
|
|
params_file = os.path.join(args.logs, args.exp, "params.json") |
|
with open(params_file, "w") as fd: |
|
params = {name: getattr(args, name) for name in vars(args)} |
|
json.dump(params, fd, sort_keys=True, indent=4) |
|
|
|
|
|
data = {} |
|
for split, path in args.datasets.items(): |
|
data[split] = torch.utils.data.DataLoader( |
|
datasets.ImageFolder(path, transform=preprocess_val), |
|
batch_size=args.batch_size, |
|
num_workers=args.workers, |
|
sampler=None, |
|
shuffle=False, |
|
) |
|
|
|
model.eval() |
|
results = evaluate(model, data, args) |
|
|
|
results_file = os.path.join(args.logs, args.exp, "results.json") |
|
with open(results_file, "w") as fd: |
|
json.dump(results, fd, indent=4, sort_keys=True) |
|
|