Getting logits and the KV cache from Moondream

#24
by sachin - opened

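I am trying to see whether Moondream can answer multiple yes/no questions about the same image. I was able to do this with Phi-3 Vision using the snippet below. It encodes the image and the shared prompt prefix once, reuses that KV cache for every question, and compares the next-token logits for "yes" and "no". I was wondering whether I could do something similar (efficiently) with Moondream.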

What I tried: calling `model(enc_image, "Describe the image.", tokenizer)` (IPython cell `In [12]`) raised `NotImplementedError: Module [Moondream] is missing the required "forward" function`.

import copy

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
from transformers.image_utils import load_image

# Checkpoint to load; it ships its own modelling/processing code (trust_remote_code).
MODEL_ID = "microsoft/Phi-3-vision-128k-instruct"

# Device that processed inputs are moved to (see the `.to(device)` calls).
# Model weights are placed by device_map="auto", not by this variable.
device = "cuda"

# 4-bit NF4 weight quantization with double quantization; compute in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Eager attention is requested explicitly via _attn_implementation.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map="auto",
    _attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# The single image every question below is asked about.
image = load_image("https://sm.ign.com/t/ign_ap/review/d/deadpool-r/deadpool-review_2s7s.1200.jpg")

tokenizer = processor.tokenizer


def _single_token_id(word):
    """Return the token id of ``word`` as a ``(1, 1)`` int64 tensor.

    Args:
        word: Text that must encode to exactly one token (no special tokens).

    Raises:
        ValueError: If ``word`` does not encode to exactly one token.
    """
    token_ids = tokenizer.encode(word, add_special_tokens=False)
    # Validate the whole encoding *before* picking an id: slicing out a single
    # id first always yields one token and would hide multi-token words.
    if len(token_ids) != 1:
        raise ValueError(f"{word} id is multiple tokens")
    return torch.tensor(token_ids).view(1, 1)


yes_id = _single_token_id("yes")
no_id = _single_token_id("no")

# Shared prefix: instruction, image placeholder and the open-ended question stem.
# It is run through the model once; its KV cache is kept for reuse.
prompt = (
    "<|user|>\n"
    "You are an experienced marvel fan. Only answer yes or no. "
    "<|image_1|>\n"
    " Does this image contain "
)
# Suffix appended after each candidate class name to close the user turn.
prompt_end = "<|end|>\n<|assistant|>\n"

root_inputs = processor(
    text=prompt,
    images=[image],
    padding="longest",
    return_tensors="pt",
).to(device)

with torch.inference_mode():
    kv_cache = model(**root_inputs, return_dict=True).past_key_values

# Score each candidate class: append "<class_name><|end|>\n<|assistant|>\n" to
# the cached image prefix and compare the next-token logits of "yes" vs "no".
class_names = ["a superhero", "a fireplace", "a bear", "random thought"]
# Plain ints for indexing; int() accepts ints and 1-element tensors alike.
yes_idx, no_idx = int(yes_id), int(no_id)
probs_iterative = []

with torch.inference_mode():
    for class_name in class_names:
        # Tokenize only the continuation, without special tokens. If the
        # tokenizer is configured to add BOS, a BOS here would land in the
        # middle of the sequence and distort the scores.
        inputs = tokenizer(
            class_name + prompt_end,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(device)
        # The mask must cover the cached prefix plus the new tokens.
        inputs["attention_mask"] = torch.cat(
            [root_inputs["attention_mask"], inputs["attention_mask"]], dim=-1
        )
        # Cache objects (e.g. transformers' DynamicCache) are extended in place
        # by a forward pass; copy those so every class starts from the same
        # prefix. Legacy tuple caches are not mutated and are reused as-is.
        past = kv_cache if isinstance(kv_cache, tuple) else copy.deepcopy(kv_cache)
        outputs = model(**inputs, past_key_values=past, return_dict=True)
        # Pick the two candidate logits directly (no host round-trip) and
        # softmax them in float32.
        logits = outputs.logits[0, -1, [yes_idx, no_idx]].float()
        probs_iterative.append(F.softmax(logits, dim=-1))

        print(f"The probability of seeing {class_name} is {probs_iterative[-1][0].item():.4f}")

Sign up or log in to comment