Spaces:
Configuration error
Configuration error
File size: 2,905 Bytes
705d528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
"""
Makes the entire set of text embeddings for all possible names in the tree of life.
Uses the catalog.csv file from TreeOfLife-10M.
"""
import argparse
import csv
import json
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
import lib
from templates import openai_imagenet_template
# Model identifier handed to open_clip.create_model (a Hugging Face Hub reference).
model_str = "hf-hub:imageomics/bioclip"
# Tokenizer name handed to open_clip.get_tokenizer.
tokenizer_str = "ViT-B-16"
# Use the first CUDA device when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@torch.no_grad()
def write_txt_features(name_lookup):
    """Encode every name in ``name_lookup`` and write the embeddings to disk.

    The output is a float32 ``np.memmap`` at ``args.out_path`` with shape
    ``(512, name_lookup.size)``. Each batch is written into the columns given
    by the ``indices`` that ``lib.batched`` yields alongside the names.

    Each name is expanded with every template in ``openai_imagenet_template``.
    The per-template text embeddings are L2-normalized, averaged over the
    templates, and then re-normalized to unit length.

    This function reads the module-level globals ``args`` (``out_path`` and
    ``batch_size``), ``model``, ``tokenizer`` and ``device``, which are set up
    in this script's ``__main__`` section.

    Args:
        name_lookup: the ``lib.TaxonomicTree`` built by ``get_name_lookup``.
            Its ``size`` attribute sets the number of output columns.
    """
    n_templates = len(openai_imagenet_template)
    embed_dim = 512  # width of the output array; must match the encoder output
    all_features = np.memmap(
        args.out_path, dtype=np.float32, mode="w+", shape=(embed_dim, name_lookup.size)
    )
    # --batch-size counts prompts, so convert it to names per batch. Clamp to
    # at least 1: a value smaller than the template count would otherwise
    # floor to 0.
    batch_size = max(1, args.batch_size // n_templates)
    for names, indices in tqdm(lib.batched(name_lookup, batch_size)):
        # Name-major ordering: all templates of name 0, then all of name 1, ...
        txts = [template(name) for name in names for template in openai_imagenet_template]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # The last batch is usually shorter than batch_size, so infer the
        # leading (names) dimension instead of hard-coding it.
        txt_features = torch.reshape(txt_features, (-1, n_templates, embed_dim))
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.cpu().numpy().T
    all_features.flush()
def get_name_lookup(catalog_path):
    """Build a ``lib.TaxonomicTree`` from a TreeOfLife-10M ``catalog.csv``.

    Every row contributes its kingdom -> species path. The path is cut at the
    first missing (empty or ``None``) rank, so a partially labelled row adds
    only the prefix of ranks that is known.

    Args:
        catalog_path: path to the catalog CSV. It must contain the columns
            kingdom, phylum, class, order, family, genus and species.

    Returns:
        The populated ``lib.TaxonomicTree``.
    """
    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    lookup = lib.TaxonomicTree()
    # newline="" is what the csv module requires for correct quoting/newlines.
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader):
            name = []
            for rank in ranks:
                value = row[rank]
                # Stop at the first missing rank. DictReader fills absent
                # fields of short rows with None, and name.index("") cannot
                # find those.
                if not value:
                    break
                name.append(value)
            lookup.add(name)
    return lookup
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    cli.add_argument("--out-path", help="Path to the output file.", required=True)
    cli.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default=".name_lookup_cache.json",
    )
    cli.add_argument("--batch-size", help="Batch size.", default=1 << 15, type=int)
    # NOTE: write_txt_features reads `args`, `model` and `tokenizer` as module
    # globals, so these three names must not be renamed.
    args = cli.parse_args()

    # Build the taxonomic name tree once and persist it as JSON next to the run.
    name_lookup = get_name_lookup(args.catalog_path)
    with open(args.name_cache_path, "w") as cache_file:
        json.dump(name_lookup, cache_file, cls=lib.TaxonomicJsonEncoder)

    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True).to(device)
    print("Created model.")
    # NOTE(review): write_txt_features calls model.encode_text on the compiled
    # wrapper; confirm that this method call actually goes through the compiled
    # graph rather than the original module.
    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)
|