Samuel Stevens commited on
Commit
9cbc5ff
1 Parent(s): a33c93d

fix bug. todo: merge .values and .descendants

Browse files
Files changed (1) hide show
  1. make_txt_embedding.py +2 -2
make_txt_embedding.py CHANGED
@@ -39,7 +39,7 @@ def write_txt_features(name_lookup):
39
  batch_size = args.batch_size // len(openai_imagenet_template)
40
  for batch, (names, indices) in enumerate(
41
  tqdm(
42
- lib.batched(name_lookup, batch_size),
43
  desc="txt feats",
44
  total=len(name_lookup) // batch_size,
45
  )
@@ -84,7 +84,7 @@ def convert_txt_features_to_avgs(name_lookup):
84
  if rank == "Species":
85
  continue
86
  for name, index in tqdm(all_names[i], desc=rank):
87
- species = tuple(zip(*((d, i) for d, i in name_lookup.descendants(prefix=name) if len(d) >= 7)))
88
  if not species:
89
  logger.warning("No species for %s.", " ".join(name))
90
  all_features[:, index] = 0.0
 
39
  batch_size = args.batch_size // len(openai_imagenet_template)
40
  for batch, (names, indices) in enumerate(
41
  tqdm(
42
+ lib.batched(name_lookup.values(), batch_size),
43
  desc="txt feats",
44
  total=len(name_lookup) // batch_size,
45
  )
 
84
  if rank == "Species":
85
  continue
86
  for name, index in tqdm(all_names[i], desc=rank):
87
+ species = tuple(zip(*((d, i) for d, i in name_lookup.descendants(prefix=name) if len(d) >= 6)))
88
  if not species:
89
  logger.warning("No species for %s.", " ".join(name))
90
  all_features[:, index] = 0.0