"""
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""

import argparse
import csv
import itertools
import json
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm

import lib
from templates import openai_imagenet_template

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()

model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

# Width of the text embedding. The encoder output is reshaped to this size, so it
# must match the model loaded from `model_str`.
embed_dim = 512


@torch.no_grad()
def write_txt_features(name_lookup):
    """
    Encode every name in `name_lookup` and save the features to `args.out_path`.

    Each name is expanded with every template in `openai_imagenet_template`.
    The per-template embeddings are L2-normalized, averaged, and re-normalized.
    The result is stored in column `index` of an `(embed_dim, len(name_lookup))`
    array.

    If `args.out_path` already exists, it is loaded and any batch whose columns
    already contain non-zero values is skipped, so an interrupted run resumes.
    Progress is saved every 100 batches and once more at the end.

    Relies on the module-level globals `args`, `model` and `tokenizer` that are
    set up under `__main__`.

    Args:
        name_lookup: lib.TaxonomicTree whose `.values()` yields `(name, index)`
            pairs. `lib.batched` is assumed to regroup them into
            `(names, indices)` batches.
    """
    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        all_features = np.zeros((embed_dim, len(name_lookup)), dtype=np.float32)

    n_templates = len(openai_imagenet_template)
    # --batch-size counts prompts (names x templates); convert it to a number of
    # names. Clamp to 1 so a small --batch-size can't produce a zero batch size.
    batch_size = max(1, args.batch_size // n_templates)
    # Ceiling division so the final, partial batch is included in the total.
    n_batches = (len(name_lookup) + batch_size - 1) // batch_size

    batches = lib.batched(name_lookup.values(), batch_size)
    for batch, (names, indices) in enumerate(
        tqdm(batches, desc="txt feats", total=n_batches)
    ):
        # Skip if any non-zero elements (already computed in a previous run).
        if all_features[:, indices].any():
            logger.info(f"Skipping batch {batch}")
            continue

        # Name-major order so the reshape below groups templates per name.
        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = torch.reshape(
            txt_features, (len(names), n_templates, embed_dim)
        )
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        if batch % 100 == 0:
            np.save(args.out_path, all_features)

    np.save(args.out_path, all_features)


def convert_txt_features_to_avgs(name_lookup):
    """
    Replace each non-species node's feature with the mean of its descendants.

    For every node above Species, the new feature is the L2-normalized mean of
    the features of its descendants whose names have at least genus depth
    (6 components). Nodes with no such descendants are set to zero.

    The result is written next to `args.out_path` with an `_avgs` suffix, e.g.
    `feats.npy` -> `feats_avgs.npy`. Relies on the module-level globals `args`
    and `device`.

    Args:
        name_lookup: lib.TaxonomicTree whose `.values()` yields `(name, index)`
            pairs and whose `.descendants(prefix=name)` yields the same kind of
            pairs.

    Raises:
        FileNotFoundError: if `write_txt_features` has not produced
            `args.out_path` yet.
    """
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(
            f"{args.out_path} not found; run write_txt_features first."
        )

    # Load the whole feature matrix onto `device` (the GPU when available) so
    # the per-node means run there.
    all_features = torch.from_numpy(np.load(args.out_path)).to(device)
    logger.info("Loaded text features from disk to %s.", device)

    # all_names[d] holds the (name, index) pairs whose name has d + 1
    # components, i.e. the nodes at rank ranks[d]. Names are presumably tuples
    # (they must be hashable to go in a set).
    all_names = [set() for _ in ranks]
    for name, index in tqdm(name_lookup.values()):
        all_names[len(name) - 1].add((name, index))

    # Descendants at Genus depth or deeper contribute to an ancestor's average.
    min_depth = ranks.index("Genus") + 1

    zeroed = 0
    # Deepest rank first; species keep their own embeddings.
    # NOTE(review): if descendants() also yields genus nodes, their columns were
    # already overwritten with genus averages by the time higher ranks are
    # processed, so those averages feed into the higher-rank means. Confirm
    # that this is intended.
    for depth, rank in reversed(list(enumerate(ranks))):
        if rank == "Species":
            continue

        for name, index in tqdm(all_names[depth], desc=rank):
            desc_indices = [
                desc_index
                for desc_name, desc_index in name_lookup.descendants(prefix=name)
                if len(desc_name) >= min_depth
            ]
            if not desc_indices:
                logger.warning("No species for %s.", " ".join(name))
                all_features[:, index] = 0.0
                zeroed += 1
                continue

            mean = all_features[:, desc_indices].mean(dim=1)
            all_features[:, index] = F.normalize(mean, dim=0)

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())
    if zeroed:
        logger.warning(
            "Zeroed out %d nodes because they didn't have any genus or species-level labels.",
            zeroed,
        )


def get_name_lookup(catalog_path, cache_path):
    """
    Build the taxonomic tree of every name in the catalog, using a JSON cache.

    If `cache_path` exists, the tree is loaded from it. Otherwise each catalog
    row's ranks (kingdom..species) are read, truncated at the first missing
    rank, and added to a new tree. The tree is then written to `cache_path`.

    Args:
        catalog_path: Path to TreeOfLife-10M's catalog.csv.
        cache_path: Path of the JSON cache to read or write.

    Returns:
        A lib.TaxonomicTree.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    lookup = lib.TaxonomicTree()
    # newline="" is required by the csv module to handle quoted newlines.
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader, desc="catalog"):
            # Catalog columns are the lowercase rank names.
            values = [row[rank.lower()] for rank in ranks]
            # Keep ranks up to (not including) the first empty one. The value
            # may be "" or None (DictReader's restval for short rows).
            name = list(itertools.takewhile(bool, values))
            lookup.add(name)

    # Write to the cache_path argument rather than the global args.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)

    return lookup


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default="name_lookup.json",
    )
    parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
    args = parser.parse_args()

    # np.save appends ".npy" to paths that lack it. Normalize up front so the
    # resume check and np.load look at the file that is actually written.
    if not args.out_path.endswith(".npy"):
        args.out_path += ".npy"

    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    logger.info("Got name lookup.")

    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")

    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)
    convert_txt_features_to_avgs(name_lookup)