object-detection-safari / classifier.py
mpsk's picture
Update classifier.py
0fc1cb9
import torch
def extract_text_feature(prompt, model, processor, device=None):
    """Extract L2-normalized OwlViT text-query embeddings for a prompt.

    Args:
        prompt: A single text query, passed to ``processor(text=...)``.
        model: OwlViT model; its ``owlvit.text_model`` and
            ``owlvit.text_projection`` submodules are used.
        processor: OwlViT processor used to tokenize ``prompt``.
        device (str or torch.device, optional): Device the token ids are
            moved to. When ``None`` (the default), ``"cuda"`` is used if it
            is available, otherwise ``"cpu"``. An explicitly passed device is
            honoured (previously it was silently ignored).

    Returns:
        tuple: ``(input_ids, query_embeds)``. ``input_ids`` is the token-id
        tensor on ``device``. ``query_embeds`` is the projected text
        embedding divided by its L2 norm along the last dimension.
    """
    if device is None:
        # Keep the historical auto-detection when the caller does not choose.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Element 1 of the text-model output is used as the sentence embedding
        # (presumably the pooled output — confirm against the OwlViT API).
        text_embeds = text_outputs[1]
        text_embeds = model.owlvit.text_projection(text_embeds)
        # The small epsilon guards against division by zero for a zero vector.
        query_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        )
    return input_ids, query_embeds
def prompt2vec(prompt: str, model, processor):
    """Convert a text prompt into token ids and a query vector as numpy arrays.

    Args:
        prompt (str): Text to be tokenized and embedded.
        model: OwlViT model, forwarded to ``extract_text_feature``.
        processor: OwlViT processor, forwarded to ``extract_text_feature``.

    Returns:
        tuple: ``(input_ids, xq)``, the token ids and the query embedding
        produced by ``extract_text_feature``. Both are detached, moved to
        CPU and converted to ``numpy.ndarray``.
    """
    input_ids, xq = extract_text_feature(prompt, model, processor)
    return input_ids.detach().cpu().numpy(), xq.detach().cpu().numpy()
def tune(clf, X, y, iters=2):
    """Fit ``clf`` on the user's feedback and hand back its refreshed weights.

    Args:
        clf: Classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``.
        X (numpy.ndarray): Input vectors (the retrieved vectors).
        y (list of floats or numpy.ndarray): Scores given by the user, one
            per entry of ``X``.
        iters (int, optional): Number of update iterations to run.

    Returns:
        The value of ``clf.get_weights()`` after fitting.

    Raises:
        AssertionError: If ``X`` and ``y`` differ in length (only while
            assertions are enabled).
    """
    n_samples, n_labels = len(X), len(y)
    assert n_samples == n_labels
    clf.fit(X, y, iters=iters)
    weights = clf.get_weights()
    return weights
class Classifier:
    """Multi-class zero-shot classifier refined from user feedback.

    This classifier acts as a proxy for the user's reaction to the probed
    images. Its weights start as the prompt-generated query vectors. They are
    then updated from user labels, so they can replace the original query
    vectors and give the user a more satisfying retrieval result. This
    resembles a recommendation system, which recommends more precise results
    as it accumulates the user's activity.

    With N queries, ``fit`` treats samples labelled ``0 .. N-1`` as positives
    for the corresponding query. It treats samples labelled ``N + 1`` as
    negatives for every query.

    Attributes:
        num_class (int): Number of queries (rows of the weight matrix).
        DIMS (int): Dimensionality of each query vector.
        weight (torch.Tensor): Current weights, shape ``(num_class, DIMS)``.
        client: Database client used by ``fit`` to compute gradients.
        obj_db (str): Table queried by ``fit``.
    """

    def __init__(self, client, obj_db: str, xq: list):
        """
        Args:
            client: Database client; ``fit`` calls ``client.query(sql)`` and
                reads the ``"grad"`` field of ``named_results()``.
            obj_db (str): Table holding the ``obj_id`` and ``prelogit`` columns.
            xq: Initial query vectors, one row per class. Accepts anything
                ``torch.Tensor`` accepts, e.g. a 2-D numpy array or a nested list.
        """
        init_weight = torch.Tensor(xq)
        # Read the shape from the tensor so nested lists work as well as arrays.
        self.num_class = init_weight.shape[0]
        self.DIMS = init_weight.shape[1]
        # The initial query vectors become the classifier weights.
        self.weight = init_weight
        self.client = client
        self.obj_db = obj_db

    def fit(self, X: list, y: list, iters: int = 5, lr: float = 0.01):
        """Update the weights with ``iters`` steps of gradient descent.

        Args:
            X (list): Samples. Each element is formatted into the SQL
                ``obj_id IN [...]`` filter, i.e. it is matched against
                ``obj_id`` in ``obj_db``.
            y (list): One label per element of ``X`` (see the class docstring).
            iters (int, optional): Number of gradient steps. Defaults to 5.
            lr (float, optional): Step size. Defaults to 0.01.
        """
        for _ in range(iters):
            # Normalize before computing the gradient. This constrains the
            # gradient; otherwise the query vector can explode.
            self.weight = self._normalized(self.weight)
            # Serialize the *current* weights every iteration so the SQL
            # gradient is evaluated at the point that is being updated.
            weight_strs = self._weights_as_sql_arrays()
            grad = torch.stack(
                [
                    self._class_gradient(n, X, y, weight_strs[n])
                    for n in range(self.num_class)
                ],
                dim=0,
            )
            # Out-of-place update: arrays previously returned by get_weights()
            # share memory with the old tensor and must not be mutated.
            self.weight = self._normalized(self.weight - lr * grad)

    @staticmethod
    def _normalized(weight):
        """Return ``weight`` with each row divided by its L2 norm."""
        return weight / torch.norm(weight, p=2, dim=-1, keepdim=True)

    def _weights_as_sql_arrays(self):
        """Render each weight row, with a trailing 1 appended, as an SQL array literal."""
        # The trailing 1 multiplies the last element of ``prelogit``, which
        # ``arrayPopBack`` excludes from the gradient inputs.
        return [
            f"[{', '.join(str(float(value)) for value in row + [1])}]"
            for row in self.get_weights().tolist()
        ]

    def _class_gradient(self, n, X, y, weight_str):
        """Return the binary cross-entropy gradient for class ``n``, computed in SQL."""
        negative = self.num_class + 1
        # NOTE(review): "the last one takes the negative one" would suggest
        # label ``num_class``, but ``num_class + 1`` is used — confirm with
        # the caller's labelling scheme.
        samples = [
            (1 if y[i] == n else 0, x)
            for i, x in enumerate(X)
            if y[i] in (n, negative)
        ]
        if not samples:
            # No labelled samples for this class: zero gradient instead of
            # failing to unpack an empty result.
            return torch.zeros(self.DIMS, dtype=self.weight.dtype)
        labels = [label for label, _ in samples]
        objs = [obj for _, obj in samples]
        # NOTE from @fangruil
        # Use SQL to calculate the gradient.
        # For binary cross entropy we have
        #     g = (1/(1+\exp(-XW))-Y)^TX
        # To simplify the query, the calculation is separated per class.
        # NOTE(review): groupArray() does not guarantee row order, while GT
        # follows the order of ``objs`` — verify rows and labels stay aligned.
        grad_q_str = f"""
SELECT avgForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
FROM (
    SELECT groupArray(arrayPopBack(prelogit)) AS X,
           groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {weight_str}))))) AS Y,
           {labels} AS GT
    FROM {self.obj_db} WHERE obj_id IN {objs})"""
        rows = self.client.query(grad_q_str).named_results()
        grad = [row["grad"] for row in rows][0]
        # Match the weight dtype so the update does not promote it to float64.
        return torch.as_tensor(grad, dtype=self.weight.dtype)

    def get_weights(self):
        """Return the current weights as a numpy array, one row per class.

        The array shares memory with ``self.weight``. ``fit`` rebinds
        ``self.weight`` instead of mutating it, so arrays returned earlier
        are not changed by later updates.
        """
        return self.weight.detach().numpy()
class SplitLayer(torch.nn.Module):
    """Module that splits its input into size-1 slices along the last axis."""

    def forward(self, x):
        # Tensor.split is the method form of torch.split; it yields a tuple
        # with one slice per position along the last dimension.
        return x.split(1, dim=-1)