Spaces:
Runtime error
Runtime error
File size: 3,637 Bytes
290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 2cfb891 290c238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
"""
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json
import os
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
import lib
from templates import openai_imagenet_template
# open_clip model identifier; the "hf-hub:" prefix loads BioCLIP from the Hugging Face Hub.
model_str = "hf-hub:imageomics/bioclip"
# Name passed to open_clip's get_tokenizer to tokenize the text prompts.
tokenizer_str = "ViT-B-16"
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
@torch.no_grad()
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and save the embeddings to disk.

    Each name is expanded into one prompt per template in
    ``openai_imagenet_template``. The prompts are tokenized and encoded with
    the module-level ``tokenizer`` and ``model``. The per-template embeddings
    are L2-normalized, averaged over the templates, and L2-normalized again.
    The result for the name at index ``i`` is stored in column ``i`` of a
    ``(512, len(name_lookup))`` float32 array written to ``args.out_path``.

    The computation is resumable: if ``args.out_path`` already exists it is
    loaded, and any batch whose columns already contain a non-zero value is
    skipped. A checkpoint is written every 100 batches and once at the end.

    Reads the module-level globals ``args`` (``out_path``, ``batch_size``),
    ``model``, ``tokenizer`` and ``device``, which are set up elsewhere in
    this script.

    Raises:
        ValueError: if an existing ``args.out_path`` holds an array whose
            shape is not ``(512, len(name_lookup))``.
    """
    embed_dim = 512
    n_templates = len(openai_imagenet_template)
    expected_shape = (embed_dim, len(name_lookup))

    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
        # A checkpoint built for a different set of names would be indexed
        # incorrectly, so refuse it instead of silently mixing results.
        if all_features.shape != expected_shape:
            raise ValueError(
                f"Existing features at {args.out_path} have shape "
                f"{all_features.shape}, expected {expected_shape}."
            )
    else:
        all_features = np.zeros(expected_shape, dtype=np.float32)

    # Every name becomes n_templates prompts, so shrink the name batch to keep
    # roughly args.batch_size prompts per forward pass. Never go below 1,
    # otherwise no names could ever be batched.
    batch_size = max(1, args.batch_size // n_templates)
    n_batches = -(-len(name_lookup) // batch_size)  # ceiling division

    for batch, (names, indices) in enumerate(
        tqdm(
            lib.batched(name_lookup, batch_size),
            desc="txt feats",
            total=n_batches,
        )
    ):
        # Skip if any non-zero elements: the batch was presumably computed
        # in a previous (resumed) run.
        if all_features[:, indices].any():
            print(f"Skipping batch {batch}")
            continue

        # Prompt order is name-major (all templates for names[0], then
        # names[1], ...), which the reshape below relies on.
        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        tokens = tokenizer(txts).to(device)
        txt_features = model.encode_text(tokens)
        txt_features = torch.reshape(
            txt_features, (len(names), n_templates, embed_dim)
        )
        # Normalize each template embedding, average over templates, renormalize.
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        if batch % 100 == 0:
            _save_features(args.out_path, all_features)

    _save_features(args.out_path, all_features)


def _save_features(path, features):
    """Atomically write ``features`` in ``.npy`` format to exactly ``path``.

    ``np.save`` appends ``.npy`` to a filename that lacks that suffix. That
    would put the checkpoint somewhere other than ``path``, and the resume
    check on ``path`` would then never find it. Writing through an open file
    handle keeps the name as given. Writing to a temporary file and then
    ``os.replace``-ing it means an interruption during the write cannot
    corrupt the previous checkpoint.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as fd:
        np.save(fd, features)
    os.replace(tmp_path, path)
def get_name_lookup(catalog_path, cache_path):
    """Load the taxonomic name tree from cache, or build it from the catalog.

    If ``cache_path`` exists, the tree is deserialized from that JSON file
    with ``lib.TaxonomicTree.from_dict``. Otherwise:

    1. Each row of the CSV at ``catalog_path`` is read.
    2. Its kingdom..species columns are truncated at the first missing value.
    3. The result is added to a new ``lib.TaxonomicTree``.
    4. The tree is written as JSON (via ``lib.TaxonomicJsonEncoder``) to
       ``cache_path``, the same path that is checked on later calls.

    Args:
        catalog_path: path to the ``catalog.csv`` file from TreeOfLife-10M.
        cache_path: path of the JSON name cache to read or create.

    Returns:
        The populated ``lib.TaxonomicTree``.
    """
    if os.path.isfile(cache_path):
        with open(cache_path, encoding="utf-8") as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    lookup = lib.TaxonomicTree()
    # newline="" is what the csv module expects for correct quoting handling.
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader, desc="catalog"):
            name = []
            for rank in ranks:
                value = row[rank]
                # Truncate at the first missing rank. This covers both an
                # empty string and None (which DictReader yields for short
                # rows; the old name.index("") crashed on those).
                if not value:
                    break
                name.append(value)
            lookup.add(name)

    # Write to the same path we read from, not a global CLI argument.
    with open(cache_path, "w", encoding="utf-8") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
    return lookup
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default="name_lookup.json",
    )
    parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
    args = parser.parse_args()
    # write_txt_features divides the batch size by the number of templates;
    # a smaller value would leave zero names per batch.
    if args.batch_size < len(openai_imagenet_template):
        parser.error(
            f"--batch-size must be at least {len(openai_imagenet_template)} "
            "(the number of prompt templates)."
        )

    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    print("Got name lookup.")

    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # open_clip's documentation notes models are created in train mode;
    # switch to inference mode before encoding.
    model.eval()
    print("Created model.")
    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)
|