Spaces:
Runtime error
Runtime error
""" | |
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M. | |
""" | |
import argparse | |
import csv | |
import json | |
import os | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from open_clip import create_model, get_tokenizer | |
from tqdm import tqdm | |
import lib | |
from templates import openai_imagenet_template | |
# open_clip model identifier; the "hf-hub:" prefix loads it from the Hugging Face Hub.
model_str = "hf-hub:imageomics/bioclip"
# Tokenizer configuration name passed to open_clip's get_tokenizer.
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and save the text features.

    Each name is expanded with every template in ``openai_imagenet_template``.
    The per-template embeddings are L2-normalized, averaged, and re-normalized,
    giving one 512-d column per name. The columns are stored in a
    ``(512, len(name_lookup))`` float32 array, which is written to
    ``args.out_path`` with ``np.save``.

    If ``args.out_path`` already exists, it is loaded first. Any batch whose
    columns already hold a non-zero value is skipped, so an interrupted run can
    be resumed. The array is re-saved every 100 batches and once at the end.

    Relies on the module-level ``args``, ``model`` and ``tokenizer`` that are
    created in this script's ``__main__`` section.

    Args:
        name_lookup: The name lookup to encode. Its ``len()`` sizes the output,
            and it is passed to ``lib.batched``, which yields
            ``(names, indices)`` pairs.
    """
    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)

    # args.batch_size counts encoded strings, and every name becomes one string
    # per template. Clamp to 1 so a small --batch-size can never yield 0 names
    # per batch.
    batch_size = max(1, args.batch_size // len(openai_imagenet_template))
    # Round up so the progress bar also counts the final partial batch.
    n_batches = (len(name_lookup) + batch_size - 1) // batch_size

    # Inference only. Without no_grad, autograd keeps activations for every
    # batch, and .numpy() raises on an output tensor that requires grad.
    with torch.no_grad():
        for batch, (names, indices) in enumerate(
            tqdm(
                lib.batched(name_lookup, batch_size),
                desc="txt feats",
                total=n_batches,
            )
        ):
            # Skip if any non-zero elements (computed by a previous run).
            if all_features[:, indices].any():
                print(f"Skipping batch {batch}")
                continue

            txts = [
                template(name)
                for name in names
                for template in openai_imagenet_template
            ]
            txts = tokenizer(txts).to(device)
            txt_features = model.encode_text(txts)
            # (names * templates, 512) -> (names, templates, 512)
            txt_features = torch.reshape(
                txt_features, (len(names), len(openai_imagenet_template), 512)
            )
            # Average the normalized template embeddings, then re-normalize.
            txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
            txt_features /= txt_features.norm(dim=1, keepdim=True)
            all_features[:, indices] = txt_features.T.cpu().numpy()

            # Periodic checkpoint so progress survives an interruption.
            if batch % 100 == 0:
                np.save(args.out_path, all_features)

    np.save(args.out_path, all_features)
def get_name_lookup(catalog_path, cache_path):
    """Load the taxonomic name lookup from cache, or build it from the catalog.

    If ``cache_path`` exists, the lookup is rebuilt from its JSON with
    ``lib.TaxonomicTree.from_dict``. Otherwise, every row of the CSV at
    ``catalog_path`` is read. Its kingdom..species ranks are truncated just
    before the first empty/missing rank and added to a new
    ``lib.TaxonomicTree``. The result is then written to ``cache_path`` as JSON.

    Args:
        catalog_path: Path to the catalog.csv file. It must contain kingdom,
            phylum, class, order, family, genus and species columns, otherwise
            a KeyError is raised.
        cache_path: Path of the JSON cache file to read from or write to.

    Returns:
        The ``lib.TaxonomicTree`` lookup.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")

    lookup = lib.TaxonomicTree()
    # The csv module requires files to be opened with newline="" so quoted
    # fields containing line breaks are parsed correctly.
    with open(catalog_path, newline="") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader, desc="catalog"):
            name = [row[rank] for rank in ranks]
            # Keep only the ranks before the first missing one. Test falsiness
            # rather than looking up "" so that None (which DictReader uses for
            # fields absent from a short row) is handled too.
            for i, value in enumerate(name):
                if not value:
                    name = name[:i]
                    break
            lookup.add(name)

    # Write to the cache path this function was given, not a global setting.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)

    return lookup
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default="name_lookup.json",
    )
    parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
    args = parser.parse_args()

    # write_txt_features divides --batch-size by the number of templates per
    # name. Anything smaller would leave zero names per batch.
    if args.batch_size < len(openai_imagenet_template):
        parser.error(
            f"--batch-size must be at least {len(openai_imagenet_template)} "
            "(the number of text templates per name)"
        )

    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    print("Got name lookup.")

    model = create_model(model_str, output_dict=True, require_pretrained=True)
    # Inference only: switch out of train mode (disables dropout and similar
    # train-time behavior) before compiling.
    model = model.to(device).eval()
    print("Created model.")
    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)