import torch


def extract_text_feature(prompt, model, processor, device=None):
    """Extract L2-normalised OwlViT text embeddings for a text query.

    Args:
        prompt: A single text query (anything ``processor(text=...)`` accepts).
        model: OwlViT model exposing ``owlvit.text_model`` and
            ``owlvit.text_projection``.
        processor: OwlViT processor used to tokenize ``prompt``.
        device (str or torch.device, optional): Device the token ids are moved
            to; it must match the device ``model`` lives on. Defaults to
            ``None``, meaning "use the device of the model's parameters"
            (``'cpu'`` if the model exposes none).

    Returns:
        tuple: ``(input_ids, query_embeds)`` -- the token-id tensor and the
        projected text embeddings, each row divided by its L2 norm.
    """
    if device is None:
        # Guessing cuda-if-available breaks for CPU models on GPU machines;
        # following the model's own parameters is always consistent.
        device = _model_device(model)
    with torch.no_grad():
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Element 1 of the text-model output is the one that gets projected.
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # The epsilon avoids a division by zero for an all-zero embedding.
        query_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        )
    return input_ids, query_embeds


def _model_device(model):
    """Return the device of ``model``'s first parameter, or ``'cpu'``."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return "cpu"


def prompt2vec(prompt: str, model, processor):
    """Convert a text prompt into a query vector.

    Args:
        prompt (str): Text to be tokenized and embedded.
        model: OwlViT model, forwarded to ``extract_text_feature``.
        processor: OwlViT processor, forwarded to ``extract_text_feature``.

    Returns:
        tuple: ``(input_ids, xq)`` as CPU NumPy arrays -- the token ids and
        the normalised embedding(s) representing the original prompt.
    """
    input_ids, xq = extract_text_feature(prompt, model, processor)
    return input_ids.detach().cpu().numpy(), xq.detach().cpu().numpy()


def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier on user feedback and return its weights.

    Args:
        clf: Classifier providing ``fit(X, y, iters=...)`` and
            ``get_weights()``, e.g. ``Classifier``.
        X (numpy.ndarray): Input vectors (the retrieved vectors).
        y (list of floats or numpy.ndarray): Labels given by the user, one per
            row of ``X``.
        iters (int, optional): Number of update iterations. Defaults to 2.

    Returns:
        The classifier weights after training, i.e. the refined query vectors.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )
    clf.fit(X, y, iters=iters)
    return clf.get_weights()


class Classifier:
    """Multi-class zero-shot classifier that refines query vectors from feedback.

    The classifier acts as a proxy for the user's reaction to the probed
    images. Its weight rows start out as the prompt-derived query vectors and
    are updated from user labels. The refined rows then replace the original
    queries, so retrieval results get closer to what the user wants as more
    feedback accumulates, much like a recommendation system.

    For N queries, query ``i`` owns class label ``i`` (``0 <= i < N``). Any
    other label (conventionally ``N``) marks a negative sample whose target is
    all zeros, which pushes every query away from it.
    """

    def __init__(self, xq: list):
        """Initialise the weights from the query vectors ``xq``.

        Args:
            xq: Query vectors, one per row (list, NumPy array or tensor). A
                single 1-D vector is treated as one query.
        """
        # Copy so in-place weight updates never write back into the caller's xq.
        init_weight = torch.as_tensor(xq, dtype=torch.float32).detach().clone()
        if init_weight.dim() == 1:
            init_weight = init_weight.unsqueeze(0)
        self.num_class, dims = init_weight.shape
        # The bias is omitted: only the inner product with the queries matters.
        self.model = torch.nn.Linear(dims, self.num_class, bias=False)
        # Use the query vectors themselves as the initial weight rows.
        self.model.weight = torch.nn.Parameter(init_weight)
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X: list, y: list, iters: int = 5):
        """Update the query vectors from labelled samples.

        Args:
            X (list or numpy.ndarray): Sample vectors, one per row; every row
                is L2-normalised before training.
            y (list or numpy.ndarray): Label per row of ``X`` (converted to
                int). Labels in ``[0, num_class)`` are positives for that
                query; any other label (e.g. ``num_class``) is a negative with
                an all-zero target.
            iters (int, optional): Number of gradient steps. Defaults to 5.
        """
        if len(X) == 0:
            return
        X = torch.as_tensor(X, dtype=torch.float32)
        # Out-of-place so a caller buffer shared with the tensor is untouched.
        X = X / X.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        targets = self._make_targets(torch.as_tensor(y).long())
        weight = self.model.weight
        for _ in range(iters):
            self.optimizer.zero_grad()
            # Normalise the weight rows before each forward pass; this bounds
            # the gradient and keeps the query vectors from exploding.
            with torch.no_grad():
                weight.div_(weight.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12))
            loss = self.loss(self.model(X), targets)
            loss.backward()
            self.optimizer.step()

    def _make_targets(self, y):
        """Build one-hot float targets; out-of-range labels get all-zero rows."""
        valid = (y >= 0) & (y < self.num_class)
        # Substitute a legal index for invalid labels so one_hot accepts them,
        # then blank those rows out.
        safe = torch.where(valid, y, torch.zeros_like(y))
        targets = torch.nn.functional.one_hot(
            safe, num_classes=self.num_class
        ).float()
        targets[~valid] = 0.0
        return targets

    def get_weights(self):
        """Return a NumPy copy of the current query vectors (one row per class).

        A copy is returned so later ``fit`` calls do not mutate arrays that
        were already handed out.
        """
        return self.model.weight.detach().cpu().numpy().copy()


class SplitLayer(torch.nn.Module):
    """Split the input along its last dimension into chunks of size 1."""

    def forward(self, x):
        # Returns a tuple of tensors, each of size 1 along the last dimension.
        return torch.split(x, 1, dim=-1)