# Spaces: Running on A10G
from functools import partial | |
from itertools import islice | |
from typing import Callable, List, Optional, Sequence, Union | |
import torch | |
import torch.nn.functional as F | |
def batched(iterable, n):
    """Batch data into lists of length *n*. The last batch may be shorter.

    NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        n: Maximum number of items per batch. ``None`` is handed to ``islice``
            unchanged, which puts all items into a single batch.

    Yields:
        Lists holding up to *n* consecutive items of *iterable*.

    Raises:
        ValueError: If *n* is less than one. Because this is a generator, the
            error surfaces when iteration starts, not when the function is called.
    """
    # With n == 0 the first islice() would be empty and the loop would end at once,
    # silently dropping every item; fail loudly instead (as itertools.batched does).
    if n is not None and n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while True:
        batch = list(islice(it, n))
        if not batch:
            break
        yield batch
def build_zero_shot_classifier(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        num_classes_per_batch: Optional[int] = 10,
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """ Build zero-shot classifier weights by iterating over class names in batches

    Every class name is rendered through every template. The rendered prompts are
    tokenized, encoded with ``model.encode_text`` and L2-normalized; the embeddings
    of one class are then averaged over its templates and normalized again.

    Args:
        model: CLIP model instance (only ``encode_text`` is called here)
        tokenizer: CLIP tokenizer instance, called with a list of strings; its result must support ``.to(device)``
        classnames: A non-empty sequence of class (label) names
        templates: A non-empty sequence of callables and/or format() friendly strings to produce
            templates per class name. Strings and callables may be mixed freely.
        num_classes_per_batch: The number of classes to batch together in each forward, all if None (or 0)
        device: Device to use.
        use_tqdm: Enable TQDM progress bar (only used when classes are processed in batches).

    Returns:
        A 2D tensor with one column per class name, in the order of ``classnames``.
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    num_templates = len(templates)
    num_classes = len(classnames)

    def _render(template, classname):
        # Decide per template (not from templates[0] only) so that str and
        # callable templates can be mixed within one sequence.
        return template.format(classname) if isinstance(template, str) else template(classname)

    def _process_batch(batch_classnames):
        num_batch_classes = len(batch_classnames)
        # Class-major order: all templates of the first class, then the second, ...
        # The reshape below relies on exactly this layout.
        texts = [_render(template, c) for c in batch_classnames for template in templates]
        texts = tokenizer(texts).to(device)
        class_embeddings = F.normalize(model.encode_text(texts), dim=-1)
        # Average the per-template embeddings of each class, then re-normalize the mean.
        class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
        # Transpose so that classes run along dim 1 (one column per class).
        return class_embeddings.T

    with torch.no_grad():
        if not num_classes_per_batch:
            # None or 0: a single forward over all classes. No progress bar is set up
            # here, which also avoids dividing by a zero batch size below.
            return _process_batch(classnames)

        if use_tqdm:
            import tqdm
            num_iter = (num_classes - 1) // num_classes_per_batch + 1  # ceil division
            iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
        else:
            iter_wrap = iter

        batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
        return torch.cat(batched_embeds, dim=1)
def build_zero_shot_classifier_legacy(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """ Build zero-shot classifier weights by iterating over class names 1 by 1

    For each class name, all templates are rendered, tokenized and encoded with
    ``model.encode_text``; the L2-normalized embeddings are averaged over the
    templates and the mean is normalized again.

    Args:
        model: CLIP model instance (only ``encode_text`` is called here)
        tokenizer: CLIP tokenizer instance, called with a list of strings; its result must support ``.to(device)``
        classnames: A non-empty sequence of class (label) names
        templates: A non-empty sequence of callables and/or format() friendly strings to produce
            templates per class name. Strings and callables may be mixed freely.
        device: Device to use.
        use_tqdm: Enable TQDM progress bar.

    Returns:
        A tensor with the per-class embeddings stacked along dim 1 (one column per
        class name, in the order of ``classnames``), moved to ``device``.
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    if use_tqdm:
        import tqdm
        iter_wrap = tqdm.tqdm
    else:
        iter_wrap = iter

    with torch.no_grad():
        zeroshot_weights = []
        for classname in iter_wrap(classnames):
            # Decide per template (not from templates[0] only) so that str and
            # callable templates can be mixed within one sequence.
            texts = [
                template.format(classname) if isinstance(template, str) else template(classname)
                for template in templates
            ]
            texts = tokenizer(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)
            # Normalize each template embedding, average them, then re-normalize the mean.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights