akshit-g committed
Commit af62a64 · 1 Parent(s): 6c84df1

add : app.py

moondream/__init__.py ADDED
File without changes
moondream/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (151 Bytes)
 
moondream/eval/docvqa.py ADDED
@@ -0,0 +1,45 @@
+ import editdistance
+ from datasets import load_dataset
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from ..hf import detect_device
+
+ MODEL_ID = "vikhyatk/moondream2"
+ DEVICE, DTYPE = detect_device()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ moondream = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     attn_implementation="flash_attention_2",
+     torch_dtype=DTYPE,
+     device_map={"": DEVICE},
+ )
+ moondream.eval()
+
+
+ def get_anls(s1, s2):
+     s1 = s1.lower().strip()
+     s2 = s2.lower().strip()
+     iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))
+     anls = iou if iou >= 0.5 else 0.0
+     return anls
+
+
+ docvqa_val = load_dataset("vikhyatk/docvqa", split="validation")
+
+ scores = []
+ for row in tqdm(docvqa_val):
+     image = row["image"]
+     enc_image = moondream.encode_image(image)
+     for qa in row["qa"]:
+         question = qa["question"]
+         answers = qa["answers"]
+         prompt = f"{question}\nAnswer briefly with a single word or phrase."
+
+         model_answer = moondream.answer_question(enc_image, prompt, tokenizer)
+         anls = max(get_anls(model_answer, gt) for gt in answers)
+         scores.append(anls)
+
+ print("ANLS:", sum(scores) / len(scores))
moondream/eval/naturalbench.py ADDED
@@ -0,0 +1,74 @@
+ from datasets import load_dataset
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from ..hf import detect_device
+
+ MODEL_ID = "vikhyatk/moondream2"
+ DEVICE, DTYPE = detect_device()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ moondream = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     attn_implementation="flash_attention_2",
+     torch_dtype=DTYPE,
+     device_map={"": DEVICE},
+ )
+ moondream.eval()
+
+ # Yes, the benchmark test set is stored in the 'train' split...
+ dataset = load_dataset("BaiqiL/NaturalBench", split="train")
+
+ acc = []
+ q_acc = []
+ i_acc = []
+ g_acc = []
+
+ for row in tqdm(dataset):
+     if row["Question_Type"] == "yes_no":
+         suffix = " Answer yes or no."
+     else:
+         suffix = ""
+
+     answers = moondream.batch_answer(
+         images=[row["Image_0"], row["Image_1"], row["Image_0"], row["Image_1"]],
+         prompts=[
+             row["Question_0"] + suffix,
+             row["Question_0"] + suffix,
+             row["Question_1"] + suffix,
+             row["Question_1"] + suffix,
+         ],
+         tokenizer=tokenizer,
+     )
+
+     expected = [
+         row["Image_0_Question_0"],
+         row["Image_1_Question_0"],
+         row["Image_0_Question_1"],
+         row["Image_1_Question_1"],
+     ]
+
+     acc.append(answers[0] == expected[0])
+     acc.append(answers[1] == expected[1])
+     acc.append(answers[2] == expected[2])
+     acc.append(answers[3] == expected[3])
+
+     i_acc.append(answers[0] == expected[0] and answers[2] == expected[2])
+     i_acc.append(answers[1] == expected[1] and answers[3] == expected[3])
+
+     q_acc.append(answers[0] == expected[0] and answers[1] == expected[1])
+     q_acc.append(answers[2] == expected[2] and answers[3] == expected[3])
+
+     g_acc.append(
+         answers[0] == expected[0]
+         and answers[1] == expected[1]
+         and answers[2] == expected[2]
+         and answers[3] == expected[3]
+     )
+
+
+ print("Overall Accuracy:", sum(acc) / len(acc))
+ print("Image Accuracy:", sum(i_acc) / len(i_acc))
+ print("Question Accuracy:", sum(q_acc) / len(q_acc))
+ print("Group Accuracy:", sum(g_acc) / len(g_acc))
moondream/eval/pope.py ADDED
@@ -0,0 +1,64 @@
+ from datasets import load_dataset
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from ..hf import detect_device
+
+ MODEL_ID = "vikhyatk/moondream2"
+ DEVICE, DTYPE = detect_device()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ moondream = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     attn_implementation="flash_attention_2",
+     torch_dtype=DTYPE,
+     device_map={"": DEVICE},
+ )
+ moondream.eval()
+
+ pope_dataset = load_dataset("vikhyatk/POPE", split="test")
+
+ stats = {
+     "random": (0, 0),
+     "popular": (0, 0),
+     "adversarial": (0, 0),
+ }
+ for row in tqdm(pope_dataset):
+     image = row["image"]
+     enc_image = moondream.encode_image(image)
+     for split in ["adversarial", "popular", "random"]:
+         for qa in row[split]:
+             question = qa["question"]
+             answer = qa["answer"]
+             prompt = f"{question}\nAnswer yes or no."
+             model_answer = moondream.answer_question(enc_image, prompt, tokenizer)
+             if model_answer.lower() == answer.lower():
+                 stats[split] = (stats[split][0] + 1, stats[split][1] + 1)
+             else:
+                 stats[split] = (stats[split][0], stats[split][1] + 1)
+
+ print(
+     "Random:",
+     stats["random"][0],
+     "/",
+     stats["random"][1],
+     ":",
+     stats["random"][0] * 100.0 / stats["random"][1],
+ )
+ print(
+     "Popular:",
+     stats["popular"][0],
+     "/",
+     stats["popular"][1],
+     ":",
+     stats["popular"][0] * 100.0 / stats["popular"][1],
+ )
+ print(
+     "Adversarial:",
+     stats["adversarial"][0],
+     "/",
+     stats["adversarial"][1],
+     ":",
+     stats["adversarial"][0] * 100.0 / stats["adversarial"][1],
+ )
moondream/eval/tallyqa.py ADDED
@@ -0,0 +1,72 @@
+ # Expects Visual Genome to be downloaded to `data/vg` and the TallyQA test set
+ # to be present at `data/tallyqa/test.json`.
+ #
+ # Steps to download Visual Genome and TallyQA:
+ #
+ # mkdir -p data/vg/VG_100K
+ # mkdir -p data/vg/VG_100K_2
+ # mkdir -p data/tallyqa
+ # wget -P data/vg/VG_100K_2/ https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
+ # wget -P data/vg/VG_100K/ https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip
+ # wget -P data/tallyqa/ https://github.com/manoja328/TallyQA_dataset/raw/master/tallyqa.zip
+ # unzip data/vg/VG_100K_2/images2.zip -d data/vg/
+ # unzip data/vg/VG_100K/images.zip -d data/vg/
+ # unzip data/tallyqa/tallyqa.zip -d data/tallyqa/
+ # rm data/vg/VG_100K_2/images2.zip
+ # rm data/vg/VG_100K/images.zip
+ # rm data/tallyqa/tallyqa.zip
+
+ import json
+
+ from PIL import Image
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ from ..hf import Moondream, detect_device
+
+ BATCH_SIZE = 16
+ DEVICE, DTYPE = detect_device()
+
+ model_id = "vikhyatk/moondream2"
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ model = Moondream.from_pretrained(
+     model_id,
+     attn_implementation="flash_attention_2",
+     torch_dtype=DTYPE,
+     device_map={"": DEVICE},
+ )
+ model.eval()
+
+ total = 0
+ total_simple = 0
+ correct = 0
+ correct_simple = 0
+
+ # Iterate over tallyqa_test in batches of BATCH_SIZE
+ tallyqa_test = json.load(open("data/tallyqa/test.json"))
+ for i in tqdm(range(0, len(tallyqa_test), BATCH_SIZE)):
+     batch = tallyqa_test[i : i + BATCH_SIZE]
+
+     images = [Image.open(f"data/vg/{item['image']}") for item in batch]
+     questions = [
+         item["question"] + " Answer in a word or phrase only." for item in batch
+     ]
+
+     answers = model.batch_answer(
+         images=images, prompts=questions, tokenizer=tokenizer, max_new_tokens=10
+     )
+
+     for answer, item in zip(answers, batch):
+         is_simple = item["issimple"]
+         is_correct = 1 if str(item["answer"]) == answer else 0
+
+         total += 1
+         correct += is_correct
+         if is_simple:
+             total_simple += 1
+             correct_simple += is_correct
+
+ print(
+     f"Simple: {total_simple}, Correct: {correct_simple}, Accuracy: {correct_simple*100.0/total_simple}"
+ )
+ print(f"Total: {total}, Correct: {correct}, Accuracy: {correct*100.0/total}")
moondream/hf/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .moondream import Moondream
+ from .util import LATEST_REVISION, detect_device
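
Note: these re-exports are what the eval scripts use via `from ..hf import ...`. From outside the package, assuming the repository root is on the import path, the equivalent would presumably be:

from moondream.hf import Moondream, detect_device

device, dtype = detect_device()  # returns a (device, dtype) pair, as used by the eval scripts
model = Moondream.from_pretrained(
    "vikhyatk/moondream2", torch_dtype=dtype, device_map={"": device}
)
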
moondream/hf/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (272 Bytes)
 
moondream/hf/__pycache__/configuration_moondream.cpython-312.pyc ADDED
Binary file (3.58 kB)
 
moondream/hf/__pycache__/fourier_features.cpython-312.pyc ADDED
Binary file (1.37 kB)
 
moondream/hf/__pycache__/modeling_phi.cpython-312.pyc ADDED
Binary file (62.2 kB)
 
moondream/hf/__pycache__/moondream.cpython-312.pyc ADDED
Binary file (14.5 kB)
 
moondream/hf/__pycache__/region_model.cpython-312.pyc ADDED
Binary file (4.48 kB)
 
moondream/hf/__pycache__/util.cpython-312.pyc ADDED
Binary file (948 Bytes)
 
moondream/hf/__pycache__/vision_encoder.cpython-312.pyc ADDED
Binary file (16.8 kB)
 
moondream/hf/configuration_moondream.py ADDED
@@ -0,0 +1,96 @@
+ from transformers import PretrainedConfig
+
+
+ class PhiConfig(PretrainedConfig):
+     model_type = "phi"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=51200,
+         hidden_size=2048,
+         intermediate_size=8192,
+         num_hidden_layers=24,
+         num_attention_heads=32,
+         num_key_value_heads=None,
+         resid_pdrop=0.0,
+         embd_pdrop=0.0,
+         attention_dropout=0.0,
+         hidden_act="gelu_new",
+         max_position_embeddings=2048,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         use_cache=True,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         partial_rotary_factor=0.5,
+         bos_token_id=1,
+         eos_token_id=2,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.resid_pdrop = resid_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.attention_dropout = attention_dropout
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.partial_rotary_factor = partial_rotary_factor
+         self._rope_scaling_validation()
+
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+     # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
+     def _rope_scaling_validation(self):
+         """
+         Validate the `rope_scaling` configuration.
+         """
+         if self.rope_scaling is None:
+             return
+
+         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+             raise ValueError(
+                 "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                 f"got {self.rope_scaling}"
+             )
+         rope_scaling_type = self.rope_scaling.get("type", None)
+         rope_scaling_factor = self.rope_scaling.get("factor", None)
+         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+             raise ValueError(
+                 f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+             )
+         if (
+             rope_scaling_factor is None
+             or not isinstance(rope_scaling_factor, float)
+             or rope_scaling_factor <= 1.0
+         ):
+             raise ValueError(
+                 f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
+             )
+
+
+ class MoondreamConfig(PretrainedConfig):
+     model_type = "moondream1"
+
+     def __init__(self, **kwargs):
+         self.text_config = PhiConfig(**kwargs.pop("text_config", {}))
+         super().__init__(**kwargs)
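
Note: as a quick illustration of the validation above (import path assumed from this commit's layout), `rope_scaling` must be a two-field dict with a `type` of "linear" or "dynamic" and a float `factor` greater than 1:

from moondream.hf.configuration_moondream import MoondreamConfig

cfg = MoondreamConfig(text_config={"rope_scaling": {"type": "linear", "factor": 2.0}})
print(cfg.text_config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}

# A factor <= 1.0, or a non-float factor, would raise ValueError in _rope_scaling_validation.
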
moondream/hf/fourier_features.py ADDED
@@ -0,0 +1,19 @@
+ # Adopted from https://github.com/crowsonkb/k-diffusion/blob/transformer-model-v2/k_diffusion/layers.py
+
+ import math
+
+ import torch
+ import torch.nn as nn
+
+
+ class FourierFeatures(nn.Module):
+     def __init__(self, in_features, out_features, std=1.0):
+         super().__init__()
+         assert out_features % 2 == 0
+         self.register_buffer(
+             "weight", torch.randn([out_features // 2, in_features]) * std
+         )
+
+     def forward(self, input):
+         f = 2 * math.pi * input @ self.weight.T
+         return torch.cat([f.cos(), f.sin()], dim=-1)
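
Note: FourierFeatures projects `in_features` inputs onto `out_features // 2` random frequencies and concatenates their cosines and sines, so the output width is exactly `out_features`. A small usage sketch (illustrative, import path assumed from this commit's layout):

import torch
from moondream.hf.fourier_features import FourierFeatures

ff = FourierFeatures(in_features=2, out_features=8, std=1.0)
coords = torch.rand(4, 2)   # e.g. four (x, y) points in [0, 1)
print(ff(coords).shape)     # torch.Size([4, 8])
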
moondream/hf/modeling_phi.py ADDED
@@ -0,0 +1,1477 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+ from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
28
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPast,
31
+ CausalLMOutputWithPast,
32
+ )
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import (
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ get_torch_version,
38
+ is_flash_attn_2_available,
39
+ is_flash_attn_greater_or_equal_2_10,
40
+ is_torchdynamo_compiling,
41
+ logging,
42
+ replace_return_docstrings,
43
+ )
44
+
45
+ from .configuration_moondream import PhiConfig
46
+
47
+ if is_flash_attn_2_available():
48
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CONFIG_FOR_DOC = "PhiConfig"
54
+
55
+
56
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
57
+ def _prepare_4d_causal_attention_mask_with_cache_position(
58
+ attention_mask: torch.Tensor,
59
+ sequence_length: int,
60
+ target_length: int,
61
+ dtype: torch.dtype,
62
+ device: torch.device,
63
+ min_dtype: float,
64
+ cache_position: torch.Tensor,
65
+ batch_size: int,
66
+ ):
67
+ """
68
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
69
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
70
+
71
+ Args:
72
+ attention_mask (`torch.Tensor`):
73
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
74
+ sequence_length (`int`):
75
+ The sequence length being processed.
76
+ target_length (`int`):
77
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
78
+ dtype (`torch.dtype`):
79
+ The dtype to use for the 4D attention mask.
80
+ device (`torch.device`):
81
+ The device to place the 4D attention mask on.
82
+ min_dtype (`float`):
83
+ The minimum value representable with the dtype `dtype`.
84
+ cache_position (`torch.Tensor`):
85
+ Indices depicting the position of the input sequence tokens in the sequence.
86
+ batch_size (`torch.Tensor`):
87
+ Batch size.
88
+ """
89
+ if attention_mask is not None and attention_mask.dim() == 4:
90
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
91
+ causal_mask = attention_mask
92
+ else:
93
+ causal_mask = torch.full(
94
+ (sequence_length, target_length),
95
+ fill_value=min_dtype,
96
+ dtype=dtype,
97
+ device=device,
98
+ )
99
+ if sequence_length != 1:
100
+ causal_mask = torch.triu(causal_mask, diagonal=1)
101
+ causal_mask *= torch.arange(
102
+ target_length, device=device
103
+ ) > cache_position.reshape(-1, 1)
104
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
105
+ if attention_mask is not None:
106
+ causal_mask = (
107
+ causal_mask.clone()
108
+ ) # copy to contiguous memory for in-place edit
109
+ mask_length = attention_mask.shape[-1]
110
+ padding_mask = (
111
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
112
+ )
113
+ padding_mask = padding_mask == 0
114
+ causal_mask[:, :, :, :mask_length] = causal_mask[
115
+ :, :, :, :mask_length
116
+ ].masked_fill(padding_mask, min_dtype)
117
+
118
+ return causal_mask
119
+
120
+
121
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
122
+ class PhiRotaryEmbedding(nn.Module):
123
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
124
+ super().__init__()
125
+
126
+ self.dim = dim
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.base = base
129
+ inv_freq = 1.0 / (
130
+ self.base
131
+ ** (
132
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
133
+ / self.dim
134
+ )
135
+ )
136
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
137
+
138
+ # Build here to make `torch.jit.trace` work.
139
+ self._set_cos_sin_cache(
140
+ seq_len=max_position_embeddings,
141
+ device=self.inv_freq.device,
142
+ dtype=torch.get_default_dtype(),
143
+ )
144
+
145
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
146
+ self.max_seq_len_cached = seq_len
147
+ t = torch.arange(
148
+ self.max_seq_len_cached, device=device, dtype=torch.int64
149
+ ).type_as(self.inv_freq)
150
+
151
+ freqs = torch.outer(t, self.inv_freq)
152
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
153
+ emb = torch.cat((freqs, freqs), dim=-1)
154
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
155
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
156
+
157
+ def forward(self, x, seq_len=None):
158
+ # x: [bs, num_attention_heads, seq_len, head_size]
159
+ if seq_len > self.max_seq_len_cached:
160
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
161
+
162
+ return (
163
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
164
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
165
+ )
166
+
167
+
168
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
169
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
170
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
+
172
+ def __init__(
173
+ self,
174
+ dim,
175
+ max_position_embeddings=2048,
176
+ base=10000,
177
+ device=None,
178
+ scaling_factor=1.0,
179
+ ):
180
+ self.scaling_factor = scaling_factor
181
+ super().__init__(dim, max_position_embeddings, base, device)
182
+
183
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
184
+ self.max_seq_len_cached = seq_len
185
+ t = torch.arange(
186
+ self.max_seq_len_cached, device=device, dtype=torch.int64
187
+ ).type_as(self.inv_freq)
188
+ t = t / self.scaling_factor
189
+
190
+ freqs = torch.outer(t, self.inv_freq)
191
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
192
+ emb = torch.cat((freqs, freqs), dim=-1)
193
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
194
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
195
+
196
+
197
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
198
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
199
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
200
+
201
+ def __init__(
202
+ self,
203
+ dim,
204
+ max_position_embeddings=2048,
205
+ base=10000,
206
+ device=None,
207
+ scaling_factor=1.0,
208
+ ):
209
+ self.scaling_factor = scaling_factor
210
+ super().__init__(dim, max_position_embeddings, base, device)
211
+
212
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
213
+ self.max_seq_len_cached = seq_len
214
+
215
+ if seq_len > self.max_position_embeddings:
216
+ base = self.base * (
217
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
218
+ - (self.scaling_factor - 1)
219
+ ) ** (self.dim / (self.dim - 2))
220
+ inv_freq = 1.0 / (
221
+ base
222
+ ** (
223
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
224
+ / self.dim
225
+ )
226
+ )
227
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
228
+
229
+ t = torch.arange(
230
+ self.max_seq_len_cached, device=device, dtype=torch.int64
231
+ ).type_as(self.inv_freq)
232
+
233
+ freqs = torch.outer(t, self.inv_freq)
234
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
235
+ emb = torch.cat((freqs, freqs), dim=-1)
236
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
237
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
238
+
239
+
240
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
241
+ def rotate_half(x):
242
+ """Rotates half the hidden dims of the input."""
243
+ x1 = x[..., : x.shape[-1] // 2]
244
+ x2 = x[..., x.shape[-1] // 2 :]
245
+ return torch.cat((-x2, x1), dim=-1)
246
+
247
+
248
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
249
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
250
+ """Applies Rotary Position Embedding to the query and key tensors.
251
+
252
+ Args:
253
+ q (`torch.Tensor`): The query tensor.
254
+ k (`torch.Tensor`): The key tensor.
255
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
256
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
257
+ position_ids (`torch.Tensor`):
258
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
259
+ used to pass offsetted position ids when working with a KV-cache.
260
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
261
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
262
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
263
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
264
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
265
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
266
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
267
+ Returns:
268
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
269
+ """
270
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
271
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
272
+ q_embed = (q * cos) + (rotate_half(q) * sin)
273
+ k_embed = (k * cos) + (rotate_half(k) * sin)
274
+ return q_embed, k_embed
275
+
276
+
277
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
278
+ class PhiMLP(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.config = config
282
+ self.activation_fn = ACT2FN[config.hidden_act]
283
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
284
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
285
+
286
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
287
+ hidden_states = self.fc1(hidden_states)
288
+ hidden_states = self.activation_fn(hidden_states)
289
+ hidden_states = self.fc2(hidden_states)
290
+ return hidden_states
291
+
292
+
293
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
294
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
295
+ """
296
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
297
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
298
+ """
299
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
300
+ if n_rep == 1:
301
+ return hidden_states
302
+ hidden_states = hidden_states[:, :, None, :, :].expand(
303
+ batch, num_key_value_heads, n_rep, slen, head_dim
304
+ )
305
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
306
+
307
+
308
+ class PhiAttention(nn.Module):
309
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
310
+
311
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
312
+ super().__init__()
313
+ self.config = config
314
+ self.layer_idx = layer_idx
315
+ if layer_idx is None:
316
+ logger.warning_once(
317
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
318
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
319
+ "when creating this class."
320
+ )
321
+
322
+ self.attention_dropout = config.attention_dropout
323
+ self.hidden_size = config.hidden_size
324
+ self.num_heads = config.num_attention_heads
325
+ self.head_dim = self.hidden_size // self.num_heads
326
+ self.num_key_value_heads = config.num_key_value_heads
327
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
328
+ self.max_position_embeddings = config.max_position_embeddings
329
+ self.rope_theta = config.rope_theta
330
+ self.partial_rotary_factor = config.partial_rotary_factor
331
+ self.is_causal = True
332
+
333
+ if (self.head_dim * self.num_heads) != self.hidden_size:
334
+ raise ValueError(
335
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
336
+ f" and `num_heads`: {self.num_heads})."
337
+ )
338
+
339
+ self.Wqkv = nn.Linear(
340
+ self.hidden_size, 3 * self.num_heads * self.head_dim, bias=True
341
+ )
342
+ self.out_proj = nn.Linear(
343
+ self.num_heads * self.head_dim, self.hidden_size, bias=True
344
+ )
345
+
346
+ self._init_rope()
347
+
348
+ def _init_rope(self):
349
+ if self.config.rope_scaling is None:
350
+ self.rotary_emb = PhiRotaryEmbedding(
351
+ int(self.partial_rotary_factor * self.head_dim),
352
+ max_position_embeddings=self.max_position_embeddings,
353
+ base=self.rope_theta,
354
+ )
355
+ else:
356
+ scaling_type = self.config.rope_scaling["type"]
357
+ scaling_factor = self.config.rope_scaling["factor"]
358
+ if scaling_type == "linear":
359
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
360
+ int(self.partial_rotary_factor * self.head_dim),
361
+ max_position_embeddings=self.max_position_embeddings,
362
+ scaling_factor=scaling_factor,
363
+ base=self.rope_theta,
364
+ )
365
+ elif scaling_type == "dynamic":
366
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
367
+ int(self.partial_rotary_factor * self.head_dim),
368
+ max_position_embeddings=self.max_position_embeddings,
369
+ scaling_factor=scaling_factor,
370
+ base=self.rope_theta,
371
+ )
372
+ else:
373
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
374
+
375
+ def forward(
376
+ self,
377
+ hidden_states: torch.Tensor,
378
+ attention_mask: Optional[torch.Tensor] = None,
379
+ position_ids: Optional[torch.LongTensor] = None,
380
+ past_key_value: Optional[Cache] = None,
381
+ output_attentions: bool = False,
382
+ use_cache: bool = False,
383
+ cache_position: Optional[torch.LongTensor] = None,
384
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
385
+ bsz, q_len, _ = hidden_states.size()
386
+
387
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
388
+ 3, dim=-1
389
+ )
390
+
391
+ query_states = query_states.view(
392
+ bsz, q_len, self.num_heads, self.head_dim
393
+ ).transpose(1, 2)
394
+ key_states = key_states.view(
395
+ bsz, q_len, self.num_key_value_heads, self.head_dim
396
+ ).transpose(1, 2)
397
+ value_states = value_states.view(
398
+ bsz, q_len, self.num_key_value_heads, self.head_dim
399
+ ).transpose(1, 2)
400
+
401
+ kv_seq_len = key_states.shape[-2]
402
+ if past_key_value is not None:
403
+ if self.layer_idx is None:
404
+ raise ValueError(
405
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
406
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
407
+ "with a layer index."
408
+ )
409
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
410
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
411
+
412
+ # Partial rotary embedding
413
+ query_rot, query_pass = (
414
+ query_states[..., : self.rotary_emb.dim],
415
+ query_states[..., self.rotary_emb.dim :],
416
+ )
417
+ key_rot, key_pass = (
418
+ key_states[..., : self.rotary_emb.dim],
419
+ key_states[..., self.rotary_emb.dim :],
420
+ )
421
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
422
+ query_rot, key_rot = apply_rotary_pos_emb(
423
+ query_rot, key_rot, cos, sin, position_ids
424
+ )
425
+
426
+ # [batch_size, seq_length, num_heads, head_dim]
427
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
428
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
429
+
430
+ if past_key_value is not None:
431
+ cache_kwargs = {
432
+ "sin": sin,
433
+ "cos": cos,
434
+ "partial_rotation_size": self.rotary_emb.dim,
435
+ "cache_position": cache_position,
436
+ }
437
+ key_states, value_states = past_key_value.update(
438
+ key_states, value_states, self.layer_idx, cache_kwargs
439
+ )
440
+
441
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
442
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
443
+
444
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
445
+ attn_weights = torch.matmul(
446
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
447
+ ) / math.sqrt(self.head_dim)
448
+
449
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
450
+ raise ValueError(
451
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
452
+ f" {attn_weights.size()}"
453
+ )
454
+
455
+ if attention_mask is not None:
456
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
457
+ attn_weights += causal_mask
458
+
459
+ # upcast attention to fp32
460
+ attn_weights = nn.functional.softmax(
461
+ attn_weights, dim=-1, dtype=torch.float32
462
+ ).to(value_states.dtype)
463
+ attn_weights = nn.functional.dropout(
464
+ attn_weights, p=self.attention_dropout, training=self.training
465
+ )
466
+
467
+ attn_output = torch.matmul(attn_weights, value_states)
468
+
469
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
470
+ raise ValueError(
471
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
472
+ f" {attn_output.size()}"
473
+ )
474
+
475
+ attn_output = attn_output.transpose(1, 2).contiguous()
476
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
477
+
478
+ attn_output = self.out_proj(attn_output)
479
+
480
+ if not output_attentions:
481
+ attn_weights = None
482
+
483
+ return attn_output, attn_weights, past_key_value
484
+
485
+
486
+ class PhiFlashAttention2(PhiAttention):
487
+ """
488
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
489
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
490
+ flash attention and deal with padding tokens in case the input contains any of them.
491
+ """
492
+
493
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
494
+ def __init__(self, *args, **kwargs):
495
+ super().__init__(*args, **kwargs)
496
+
497
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
498
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
499
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
500
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
501
+
502
+ def forward(
503
+ self,
504
+ hidden_states: torch.Tensor,
505
+ attention_mask: Optional[torch.LongTensor] = None,
506
+ position_ids: Optional[torch.LongTensor] = None,
507
+ past_key_value: Optional[Cache] = None,
508
+ output_attentions: bool = False,
509
+ use_cache: bool = False,
510
+ cache_position: Optional[torch.LongTensor] = None,
511
+ **kwargs,
512
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
513
+ # PhiFlashAttention2 attention does not support output_attentions
514
+
515
+ output_attentions = False
516
+
517
+ bsz, q_len, _ = hidden_states.size()
518
+
519
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
520
+ 3, dim=-1
521
+ )
522
+
523
+ # Flash attention requires the input to have the shape
524
+ # batch_size x seq_length x head_dim x hidden_dim
525
+ # therefore we just need to keep the original shape
526
+ query_states = query_states.view(
527
+ bsz, q_len, self.num_heads, self.head_dim
528
+ ).transpose(1, 2)
529
+ key_states = key_states.view(
530
+ bsz, q_len, self.num_key_value_heads, self.head_dim
531
+ ).transpose(1, 2)
532
+ value_states = value_states.view(
533
+ bsz, q_len, self.num_key_value_heads, self.head_dim
534
+ ).transpose(1, 2)
535
+
536
+ kv_seq_len = key_states.shape[-2]
537
+ if past_key_value is not None:
538
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
539
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
540
+
541
+ # Partial rotary embedding
542
+ query_rot, query_pass = (
543
+ query_states[..., : self.rotary_emb.dim],
544
+ query_states[..., self.rotary_emb.dim :],
545
+ )
546
+ key_rot, key_pass = (
547
+ key_states[..., : self.rotary_emb.dim],
548
+ key_states[..., self.rotary_emb.dim :],
549
+ )
550
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
551
+ query_rot, key_rot = apply_rotary_pos_emb(
552
+ query_rot, key_rot, cos, sin, position_ids
553
+ )
554
+
555
+ # [batch_size, seq_length, num_heads, head_dim]
556
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
557
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
558
+
559
+ if past_key_value is not None:
560
+ cache_kwargs = {
561
+ "sin": sin,
562
+ "cos": cos,
563
+ "partial_rotation_size": self.rotary_emb.dim,
564
+ "cache_position": cache_position,
565
+ }
566
+ key_states, value_states = past_key_value.update(
567
+ key_states, value_states, self.layer_idx, cache_kwargs
568
+ )
569
+
570
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
571
+ # to be able to avoid many of these transpose/reshape/view.
572
+ query_states = query_states.transpose(1, 2)
573
+ key_states = key_states.transpose(1, 2)
574
+ value_states = value_states.transpose(1, 2)
575
+
576
+ attn_dropout = self.attention_dropout if self.training else 0.0
577
+
578
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
579
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
580
+ # cast them back in the correct dtype just to be sure everything works as expected.
581
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
582
+ # in fp32.
583
+
584
+ if query_states.dtype == torch.float32:
585
+ if torch.is_autocast_enabled():
586
+ target_dtype = torch.get_autocast_gpu_dtype()
587
+ # Handle the case where the model is quantized
588
+ elif hasattr(self.config, "_pre_quantization_dtype"):
589
+ target_dtype = self.config._pre_quantization_dtype
590
+ else:
591
+ target_dtype = self.Wqkv.weight.dtype
592
+
593
+ logger.warning_once(
594
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
595
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
596
+ f" {target_dtype}."
597
+ )
598
+
599
+ query_states = query_states.to(target_dtype)
600
+ key_states = key_states.to(target_dtype)
601
+ value_states = value_states.to(target_dtype)
602
+
603
+ attn_output = _flash_attention_forward(
604
+ query_states,
605
+ key_states,
606
+ value_states,
607
+ attention_mask,
608
+ q_len,
609
+ position_ids=position_ids,
610
+ dropout=attn_dropout,
611
+ softmax_scale=None,
612
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
613
+ is_causal=self.is_causal,
614
+ )
615
+
616
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
617
+ attn_output = self.out_proj(attn_output)
618
+
619
+ if not output_attentions:
620
+ attn_weights = None
621
+
622
+ return attn_output, attn_weights, past_key_value
623
+
624
+
625
+ class PhiSdpaAttention(PhiAttention):
626
+ def __init__(self, *args, **kwargs):
627
+ super().__init__(*args, **kwargs)
628
+ self.require_contiguous_qkv = version.parse(
629
+ get_torch_version()
630
+ ) < version.parse("2.2.0")
631
+
632
+ """
633
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
634
+ `PhiAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
635
+ SDPA API.
636
+ """
637
+
638
+ # Adapted from PhiAttention.forward
639
+ def forward(
640
+ self,
641
+ hidden_states: torch.Tensor,
642
+ attention_mask: Optional[torch.Tensor] = None,
643
+ position_ids: Optional[torch.LongTensor] = None,
644
+ past_key_value: Optional[Cache] = None,
645
+ output_attentions: bool = False,
646
+ use_cache: bool = False,
647
+ cache_position: Optional[torch.LongTensor] = None,
648
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
649
+ if output_attentions:
650
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
651
+ logger.warning_once(
652
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
653
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
654
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
655
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
656
+ )
657
+ return super().forward(
658
+ hidden_states=hidden_states,
659
+ attention_mask=attention_mask,
660
+ position_ids=position_ids,
661
+ past_key_value=past_key_value,
662
+ output_attentions=output_attentions,
663
+ use_cache=use_cache,
664
+ )
665
+
666
+ bsz, q_len, _ = hidden_states.size()
667
+
668
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
669
+ 3, dim=-1
670
+ )
671
+
672
+ query_states = query_states.view(
673
+ bsz, q_len, self.num_heads, self.head_dim
674
+ ).transpose(1, 2)
675
+ key_states = key_states.view(
676
+ bsz, q_len, self.num_key_value_heads, self.head_dim
677
+ ).transpose(1, 2)
678
+ value_states = value_states.view(
679
+ bsz, q_len, self.num_key_value_heads, self.head_dim
680
+ ).transpose(1, 2)
681
+
682
+ kv_seq_len = key_states.shape[-2]
683
+ if past_key_value is not None:
684
+ if self.layer_idx is None:
685
+ raise ValueError(
686
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
687
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
688
+ "with a layer index."
689
+ )
690
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
691
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
692
+
693
+ # Partial rotary embedding
694
+ query_rot, query_pass = (
695
+ query_states[..., : self.rotary_emb.dim],
696
+ query_states[..., self.rotary_emb.dim :],
697
+ )
698
+ key_rot, key_pass = (
699
+ key_states[..., : self.rotary_emb.dim],
700
+ key_states[..., self.rotary_emb.dim :],
701
+ )
702
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
703
+ query_rot, key_rot = apply_rotary_pos_emb(
704
+ query_rot, key_rot, cos, sin, position_ids
705
+ )
706
+
707
+ # [batch_size, seq_length, num_heads, head_dim]
708
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
709
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
710
+
711
+ if past_key_value is not None:
712
+ cache_kwargs = {
713
+ "sin": sin,
714
+ "cos": cos,
715
+ "partial_rotation_size": self.rotary_emb.dim,
716
+ "cache_position": cache_position,
717
+ }
718
+ key_states, value_states = past_key_value.update(
719
+ key_states, value_states, self.layer_idx, cache_kwargs
720
+ )
721
+
722
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
723
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
724
+
725
+ causal_mask = attention_mask
726
+ if attention_mask is not None:
727
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
728
+
729
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
730
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
731
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
732
+ if (
733
+ self.require_contiguous_qkv
734
+ and query_states.device.type == "cuda"
735
+ and attention_mask is not None
736
+ ):
737
+ query_states = query_states.contiguous()
738
+ key_states = key_states.contiguous()
739
+ value_states = value_states.contiguous()
740
+
741
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
742
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
743
+ is_causal = True if causal_mask is None and q_len > 1 else False
744
+
745
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
746
+ query_states,
747
+ key_states,
748
+ value_states,
749
+ attn_mask=causal_mask,
750
+ dropout_p=self.attention_dropout if self.training else 0.0,
751
+ is_causal=is_causal,
752
+ )
753
+
754
+ attn_output = attn_output.transpose(1, 2).contiguous()
755
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
756
+
757
+ attn_output = self.out_proj(attn_output)
758
+
759
+ return attn_output, None, past_key_value
760
+
761
+
762
+ PHI_ATTENTION_CLASSES = {
763
+ "eager": PhiAttention,
764
+ "flash_attention_2": PhiFlashAttention2,
765
+ "sdpa": PhiSdpaAttention,
766
+ }
767
+
768
+
769
+ class PhiDecoderLayer(nn.Module):
770
+ def __init__(self, config: PhiConfig, layer_idx: int):
771
+ super().__init__()
772
+ self.mixer = PHI_ATTENTION_CLASSES[config._attn_implementation](
773
+ config, layer_idx=layer_idx
774
+ )
775
+ self.mlp = PhiMLP(config)
776
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
777
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
778
+
779
+ def forward(
780
+ self,
781
+ hidden_states: torch.Tensor,
782
+ attention_mask: Optional[torch.Tensor] = None,
783
+ position_ids: Optional[torch.LongTensor] = None,
784
+ output_attentions: Optional[bool] = False,
785
+ use_cache: Optional[bool] = False,
786
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
787
+ cache_position: Optional[torch.LongTensor] = None,
788
+ **kwargs,
789
+ ) -> Tuple[
790
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
791
+ ]:
792
+ """
793
+ Args:
794
+ hidden_states (`torch.FloatTensor`):
795
+ input to the layer of shape `(batch, seq_len, embed_dim)`
796
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
797
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
798
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
799
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
800
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
801
+ output_attentions (`bool`, *optional*):
802
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
803
+ returned tensors for more detail.
804
+ use_cache (`bool`, *optional*):
805
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
806
+ (see `past_key_values`).
807
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
808
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
809
+ Indices depicting the position of the input sequence tokens in the sequence
810
+ kwargs (`dict`, *optional*):
811
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
812
+ into the model
813
+ """
814
+
815
+ residual = hidden_states
816
+
817
+ hidden_states = self.ln(hidden_states)
818
+
819
+ # Self Attention
820
+ attn_outputs, self_attn_weights, present_key_value = self.mixer(
821
+ hidden_states=hidden_states,
822
+ attention_mask=attention_mask,
823
+ position_ids=position_ids,
824
+ past_key_value=past_key_value,
825
+ output_attentions=output_attentions,
826
+ use_cache=use_cache,
827
+ cache_position=cache_position,
828
+ )
829
+ attn_outputs = self.resid_dropout(attn_outputs)
830
+
831
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
832
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
833
+ outputs = (hidden_states,)
834
+
835
+ if output_attentions:
836
+ outputs += (self_attn_weights,)
837
+
838
+ if use_cache:
839
+ outputs += (present_key_value,)
840
+
841
+ return outputs
842
+
843
+
844
+ PHI_START_DOCSTRING = r"""
845
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
846
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
847
+ etc.)
848
+
849
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
850
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
851
+ and behavior.
852
+
853
+ Parameters:
854
+ config ([`PhiConfig`]):
855
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
856
+ load the weights associated with the model, only the configuration. Check out the
857
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
858
+ """
859
+
860
+
861
+ @add_start_docstrings(
862
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
863
+ PHI_START_DOCSTRING,
864
+ )
865
+ class PhiPreTrainedModel(PreTrainedModel):
866
+ config_class = PhiConfig
867
+ base_model_prefix = "model"
868
+ supports_gradient_checkpointing = True
869
+ _no_split_modules = ["PhiDecoderLayer"]
870
+ _skip_keys_device_placement = "past_key_values"
871
+ _supports_flash_attn_2 = True
872
+ _supports_sdpa = True
873
+ _supports_cache_class = True
874
+
875
+ def _init_weights(self, module):
876
+ std = self.config.initializer_range
877
+ if isinstance(module, nn.Linear):
878
+ module.weight.data.normal_(mean=0.0, std=std)
879
+ if module.bias is not None:
880
+ module.bias.data.zero_()
881
+ elif isinstance(module, nn.Embedding):
882
+ module.weight.data.normal_(mean=0.0, std=std)
883
+ if module.padding_idx is not None:
884
+ module.weight.data[module.padding_idx].zero_()
885
+
886
+
887
+ class Embedding(nn.Module):
888
+ def __init__(self, config: PhiConfig):
889
+ super().__init__()
890
+ self.wte = nn.Embedding(
891
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
892
+ )
893
+
894
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
895
+ return self.wte(input_ids)
896
+
897
+
898
+ PHI_INPUTS_DOCSTRING = r"""
899
+ Args:
900
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
+ it.
903
+
904
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
+ [`PreTrainedTokenizer.__call__`] for details.
906
+
907
+ [What are input IDs?](../glossary#input-ids)
908
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
+
911
+ - 1 for tokens that are **not masked**,
912
+ - 0 for tokens that are **masked**.
913
+
914
+ [What are attention masks?](../glossary#attention-mask)
915
+
916
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
+ `past_key_values`).
921
+
922
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
+ information on the default strategy.
925
+
926
+ - 1 indicates the head is **not masked**,
927
+ - 0 indicates the head is **masked**.
928
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
+ config.n_positions - 1]`.
931
+
932
+ [What are position IDs?](../glossary#position-ids)
933
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
936
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
+
938
+ Two formats are allowed:
939
+ - a [`~cache_utils.Cache`] instance;
940
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
+ cache format.
943
+
944
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
+ legacy cache format will be returned.
946
+
947
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
+ of shape `(batch_size, sequence_length)`.
950
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
+ model's internal embedding lookup matrix.
954
+ use_cache (`bool`, *optional*):
955
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
+ `past_key_values`).
957
+ output_attentions (`bool`, *optional*):
958
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
+ tensors for more detail.
960
+ output_hidden_states (`bool`, *optional*):
961
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
+ more detail.
963
+ return_dict (`bool`, *optional*):
964
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
+ the complete sequence length.
969
+ """
970
+
971
+
972
+ @add_start_docstrings(
973
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
+ PHI_START_DOCSTRING,
975
+ )
976
+ class PhiModel(PhiPreTrainedModel):
977
+ """
978
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
979
+
980
+ Args:
981
+ config: PhiConfig
982
+ """
983
+
984
+ def __init__(self, config: PhiConfig):
985
+ super().__init__(config)
986
+ self.padding_idx = config.pad_token_id
987
+ self.vocab_size = config.vocab_size
988
+
989
+ self.embd = Embedding(config)
990
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
991
+ self.h = nn.ModuleList(
992
+ [
993
+ PhiDecoderLayer(config, layer_idx)
994
+ for layer_idx in range(config.num_hidden_layers)
995
+ ]
996
+ )
997
+
998
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
+ self._use_sdpa = config._attn_implementation == "sdpa"
1000
+
1001
+ self.gradient_checkpointing = False
1002
+ # Initialize weights and apply final processing
1003
+ self.post_init()
1004
+
1005
+ def get_input_embeddings(self):
1006
+ return self.embd.wte
1007
+
1008
+ def set_input_embeddings(self, value):
1009
+ self.embd.wte = value
1010
+
1011
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
+ def forward(
1013
+ self,
1014
+ input_ids: torch.LongTensor = None,
1015
+ attention_mask: Optional[torch.Tensor] = None,
1016
+ position_ids: Optional[torch.LongTensor] = None,
1017
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1019
+ use_cache: Optional[bool] = None,
1020
+ output_attentions: Optional[bool] = None,
1021
+ output_hidden_states: Optional[bool] = None,
1022
+ return_dict: Optional[bool] = None,
1023
+ cache_position: Optional[torch.LongTensor] = None,
1024
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
+ output_attentions = (
1026
+ output_attentions
1027
+ if output_attentions is not None
1028
+ else self.config.output_attentions
1029
+ )
1030
+ output_hidden_states = (
1031
+ output_hidden_states
1032
+ if output_hidden_states is not None
1033
+ else self.config.output_hidden_states
1034
+ )
1035
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
+
1037
+ return_dict = (
1038
+ return_dict if return_dict is not None else self.config.use_return_dict
1039
+ )
1040
+
1041
+ if (input_ids is None) ^ (inputs_embeds is not None):
1042
+ raise ValueError(
1043
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
+ )
1045
+
1046
+ if self.gradient_checkpointing and self.training:
1047
+ if use_cache:
1048
+ logger.warning_once(
1049
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1050
+ )
1051
+ use_cache = False
1052
+
1053
+ use_legacy_cache = False
1054
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
+ use_legacy_cache = True
1056
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
+ logger.warning_once(
1058
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
1060
+ )
1061
+
1062
+ if inputs_embeds is None:
1063
+ inputs_embeds = self.embd(input_ids)
1064
+
1065
+ if cache_position is None:
1066
+ past_seen_tokens = (
1067
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1068
+ )
1069
+ cache_position = torch.arange(
1070
+ past_seen_tokens,
1071
+ past_seen_tokens + inputs_embeds.shape[1],
1072
+ device=inputs_embeds.device,
1073
+ )
1074
+ if position_ids is None:
1075
+ position_ids = cache_position.unsqueeze(0)
1076
+
1077
+ causal_mask = self._update_causal_mask(
1078
+ attention_mask,
1079
+ inputs_embeds,
1080
+ cache_position,
1081
+ past_key_values,
1082
+ output_attentions,
1083
+ )
1084
+
1085
+ hidden_states = inputs_embeds
1086
+
1087
+ # decoder layers
1088
+ all_hidden_states = () if output_hidden_states else None
1089
+ all_self_attns = () if output_attentions else None
1090
+ next_decoder_cache = None
1091
+
1092
+ for decoder_layer in self.h:
1093
+ if output_hidden_states:
1094
+ all_hidden_states += (hidden_states,)
1095
+
1096
+ if self.gradient_checkpointing and self.training:
1097
+ layer_outputs = self._gradient_checkpointing_func(
1098
+ decoder_layer.__call__,
1099
+ hidden_states,
1100
+ causal_mask,
1101
+ position_ids,
1102
+ output_attentions,
1103
+ use_cache,
1104
+ past_key_values,
1105
+ cache_position,
1106
+ )
1107
+ else:
1108
+ layer_outputs = decoder_layer(
1109
+ hidden_states,
1110
+ attention_mask=causal_mask,
1111
+ position_ids=position_ids,
1112
+ past_key_value=past_key_values,
1113
+ output_attentions=output_attentions,
1114
+ use_cache=use_cache,
1115
+ cache_position=cache_position,
1116
+ )
1117
+
1118
+ hidden_states = layer_outputs[0]
1119
+
1120
+ if use_cache:
1121
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1122
+
1123
+ if output_attentions:
1124
+ all_self_attns += (layer_outputs[1],)
1125
+
1126
+ # add hidden states from the last decoder layer
1127
+ if output_hidden_states:
1128
+ all_hidden_states += (hidden_states,)
1129
+
1130
+ next_cache = None
1131
+ if use_cache:
1132
+ next_cache = (
1133
+ next_decoder_cache.to_legacy_cache()
1134
+ if use_legacy_cache
1135
+ else next_decoder_cache
1136
+ )
1137
+ if not return_dict:
1138
+ return tuple(
1139
+ v
1140
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1141
+ if v is not None
1142
+ )
1143
+ return BaseModelOutputWithPast(
1144
+ last_hidden_state=hidden_states,
1145
+ past_key_values=next_cache,
1146
+ hidden_states=all_hidden_states,
1147
+ attentions=all_self_attns,
1148
+ )
1149
+
1150
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
+ def _update_causal_mask(
1152
+ self,
1153
+ attention_mask: torch.Tensor,
1154
+ input_tensor: torch.Tensor,
1155
+ cache_position: torch.Tensor,
1156
+ past_key_values: Cache,
1157
+ output_attentions: bool,
1158
+ ):
1159
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
+
1164
+ if self.config._attn_implementation == "flash_attention_2":
1165
+ if attention_mask is not None and 0.0 in attention_mask:
1166
+ return attention_mask
1167
+ return None
1168
+
1169
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
+ # to infer the attention mask.
1172
+ past_seen_tokens = (
1173
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1174
+ )
1175
+ using_static_cache = isinstance(past_key_values, StaticCache)
1176
+
1177
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
+ if (
1179
+ self.config._attn_implementation == "sdpa"
1180
+ and not using_static_cache
1181
+ and not output_attentions
1182
+ ):
1183
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
+ attention_mask,
1185
+ inputs_embeds=input_tensor,
1186
+ past_key_values_length=past_seen_tokens,
1187
+ is_training=self.training,
1188
+ ):
1189
+ return None
1190
+
1191
+ dtype, device = input_tensor.dtype, input_tensor.device
1192
+ min_dtype = torch.finfo(dtype).min
1193
+ sequence_length = input_tensor.shape[1]
1194
+ if using_static_cache:
1195
+ target_length = past_key_values.get_max_length()
1196
+ else:
1197
+ target_length = (
1198
+ attention_mask.shape[-1]
1199
+ if isinstance(attention_mask, torch.Tensor)
1200
+ else past_seen_tokens + sequence_length + 1
1201
+ )
1202
+
1203
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1204
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
+ attention_mask,
1206
+ sequence_length=sequence_length,
1207
+ target_length=target_length,
1208
+ dtype=dtype,
1209
+ device=device,
1210
+ min_dtype=min_dtype,
1211
+ cache_position=cache_position,
1212
+ batch_size=input_tensor.shape[0],
1213
+ )
1214
+
1215
+ if (
1216
+ self.config._attn_implementation == "sdpa"
1217
+ and attention_mask is not None
1218
+ and attention_mask.device.type == "cuda"
1219
+ and not output_attentions
1220
+ ):
1221
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1224
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1225
+ causal_mask, min_dtype
1226
+ )
1227
+
1228
+ return causal_mask
1229
+
1230
+
1231
+ class CausalLMHead(nn.Module):
1232
+ """Causal Language Modeling head. Simplified version."""
1233
+
1234
+ def __init__(self, config):
1235
+ super().__init__()
1236
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
+ self.linear = nn.Linear(config.hidden_size, config.vocab_size)
1238
+
1239
+ def forward(self, hidden_states):
1240
+ return self.linear(self.ln(hidden_states))
1241
+
1242
+
1243
+ class PhiForCausalLM(PhiPreTrainedModel):
1244
+
1245
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
+ def __init__(self, config):
1247
+ super().__init__(config)
1248
+ self.transformer = PhiModel(config)
1249
+ self.vocab_size = config.vocab_size
1250
+ self.lm_head = CausalLMHead(config)
1251
+
1252
+ # Initialize weights and apply final processing
1253
+ self.post_init()
1254
+
1255
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1256
+ def get_input_embeddings(self):
1257
+ return self.transformer.embd.wte
1258
+
1259
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
+ def set_input_embeddings(self, value):
1261
+ self.transformer.embd.wte = value
1262
+
1263
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
+ def get_output_embeddings(self):
1265
+ return self.lm_head.linear
1266
+
1267
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1268
+ def set_output_embeddings(self, new_embeddings):
1269
+ self.lm_head.linear = new_embeddings
1270
+
1271
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1272
+ def set_decoder(self, decoder):
1273
+ self.transformer = decoder
1274
+
1275
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1276
+ def get_decoder(self):
1277
+ return self.transformer
1278
+
1279
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
+ @replace_return_docstrings(
1281
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
+ )
1283
+ def forward(
1284
+ self,
1285
+ input_ids: torch.LongTensor = None,
1286
+ attention_mask: Optional[torch.Tensor] = None,
1287
+ position_ids: Optional[torch.LongTensor] = None,
1288
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1290
+ labels: Optional[torch.LongTensor] = None,
1291
+ use_cache: Optional[bool] = None,
1292
+ output_attentions: Optional[bool] = None,
1293
+ output_hidden_states: Optional[bool] = None,
1294
+ return_dict: Optional[bool] = None,
1295
+ cache_position: Optional[torch.LongTensor] = None,
1296
+ num_logits_to_keep: int = 0,
1297
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
+ r"""
1299
+ Args:
1300
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
+
1305
+ num_logits_to_keep (`int`, *optional*):
1306
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
+
1310
+ Returns:
1311
+
1312
+ Example:
1313
+
1314
+ ```python
1315
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1316
+
1317
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1318
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1319
+
1320
+ >>> prompt = "This is an example script ."
1321
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
+
1323
+ >>> # Generate
1324
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1327
+ ```"""
1328
+
1329
+ output_attentions = (
1330
+ output_attentions
1331
+ if output_attentions is not None
1332
+ else self.config.output_attentions
1333
+ )
1334
+ output_hidden_states = (
1335
+ output_hidden_states
1336
+ if output_hidden_states is not None
1337
+ else self.config.output_hidden_states
1338
+ )
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1344
+ outputs = self.transformer(
1345
+ input_ids=input_ids,
1346
+ attention_mask=attention_mask,
1347
+ position_ids=position_ids,
1348
+ past_key_values=past_key_values,
1349
+ inputs_embeds=inputs_embeds,
1350
+ use_cache=use_cache,
1351
+ output_attentions=output_attentions,
1352
+ output_hidden_states=output_hidden_states,
1353
+ return_dict=return_dict,
1354
+ cache_position=cache_position,
1355
+ )
1356
+
1357
+ hidden_states = outputs[0]
1358
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
+
1360
+ loss = None
1361
+ if labels is not None:
1362
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
+ logits = logits.float()
1364
+ # Shift so that tokens < n predict n
1365
+ shift_logits = logits[..., :-1, :].contiguous()
1366
+ shift_labels = labels[..., 1:].contiguous()
1367
+ # Flatten the tokens
1368
+ loss_fct = CrossEntropyLoss()
1369
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
+ shift_labels = shift_labels.view(-1)
1371
+ # Enable model parallelism
1372
+ shift_labels = shift_labels.to(shift_logits.device)
1373
+ loss = loss_fct(shift_logits, shift_labels)
1374
+
1375
+ if not return_dict:
1376
+ output = (logits,) + outputs[1:]
1377
+ return (loss,) + output if loss is not None else output
1378
+
1379
+ return CausalLMOutputWithPast(
1380
+ loss=loss,
1381
+ logits=logits,
1382
+ past_key_values=outputs.past_key_values,
1383
+ hidden_states=outputs.hidden_states,
1384
+ attentions=outputs.attentions,
1385
+ )
1386
+
1387
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1388
+ def prepare_inputs_for_generation(
1389
+ self,
1390
+ input_ids,
1391
+ inputs_embeds=None,
1392
+ past_key_values=None,
1393
+ attention_mask=None,
1394
+ cache_position=None,
1395
+ position_ids=None,
1396
+ use_cache=True,
1397
+ num_logits_to_keep=0,
1398
+ **kwargs,
1399
+ ):
1400
+ assert inputs_embeds is not None, "inputs_embeds is required"
1401
+
1402
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1403
+ if past_key_values is not None:
1404
+ # When doing custom decoding for object detection, we don't update input_ids.
1405
+ # So we will slice `inputs_embeds` instead.
1406
+ if input_ids.shape[1] == 0:
1407
+ inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
1408
+ else:
1409
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1410
+
1411
+ if attention_mask is not None and position_ids is None:
1412
+ # create position_ids on the fly for batch generation
1413
+ position_ids = attention_mask.long().cumsum(-1) - 1
1414
+ position_ids.masked_fill_(attention_mask == 0, 1)
1415
+ if past_key_values:
1416
+ if input_ids.shape[1] == 0:
1417
+ position_ids = position_ids[:, -inputs_embeds.shape[1] :]
1418
+ else:
1419
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1420
+
1421
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
1422
+ # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying
1423
+ # stride during decoding. Here, simply using `.contiguous()` is not sufficient as
1424
+ # in the batch size = 1 case, `position_ids` is already contiguous but with varying
1425
+ # stride which retriggers a capture.
1426
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1427
+
1428
+ if cache_position[0] == 0:
1429
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1430
+ else:
1431
+ # The clone here is for the same reason as for `position_ids`.
1432
+ if past_key_values is not None and input_ids.shape[1] == 0:
1433
+ model_inputs = {
1434
+ "input_ids": None,
1435
+ "inputs_embeds": inputs_embeds.clone(
1436
+ memory_format=torch.contiguous_format
1437
+ ),
1438
+ }
1439
+ else:
1440
+ model_inputs = {
1441
+ "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1442
+ "inputs_embeds": None,
1443
+ }
1444
+
1445
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1446
+ if model_inputs["inputs_embeds"] is not None:
1447
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1448
+ device = model_inputs["inputs_embeds"].device
1449
+ else:
1450
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1451
+ device = model_inputs["input_ids"].device
1452
+
1453
+ dtype = self.lm_head.weight.dtype
1454
+ min_dtype = torch.finfo(dtype).min
1455
+
1456
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1457
+ attention_mask,
1458
+ sequence_length=sequence_length,
1459
+ target_length=past_key_values.get_max_length(),
1460
+ dtype=dtype,
1461
+ device=device,
1462
+ min_dtype=min_dtype,
1463
+ cache_position=cache_position,
1464
+ batch_size=batch_size,
1465
+ )
1466
+
1467
+ model_inputs.update(
1468
+ {
1469
+ "position_ids": position_ids,
1470
+ "cache_position": cache_position,
1471
+ "past_key_values": past_key_values,
1472
+ "use_cache": use_cache,
1473
+ "attention_mask": attention_mask,
1474
+ "num_logits_to_keep": num_logits_to_keep,
1475
+ }
1476
+ )
1477
+ return model_inputs
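For reference, a small sketch of what the `num_logits_to_keep` argument documented above saves during generation; the tensor sizes below are illustrative and not taken from this file.

import torch

hidden_states = torch.randn(1, 10, 2048)  # (batch, seq_len, hidden_size)
num_logits_to_keep = 1                    # generation only needs the last position

# Mirrors `self.lm_head(hidden_states[:, -num_logits_to_keep:, :])` in forward() above:
# with 0 the slice keeps every position, with 1 only the final token reaches the LM head.
tail = hidden_states[:, -num_logits_to_keep:, :]
print(tail.shape)  # torch.Size([1, 1, 2048])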
moondream/hf/moondream.py ADDED
@@ -0,0 +1,352 @@
1
+ from typing import List, Literal, Optional
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import PreTrainedModel
6
+
7
+ from .configuration_moondream import MoondreamConfig, PhiConfig
8
+ from .modeling_phi import PhiForCausalLM
9
+ from .region_model import RegionModel
10
+ from .vision_encoder import VisionEncoder
11
+
12
+
13
+ class Moondream(PreTrainedModel):
14
+ config_class = MoondreamConfig
15
+ _supports_flash_attn_2 = True
16
+
17
+ def __init__(self, config):
18
+ super().__init__(config)
19
+ self.vision_encoder = VisionEncoder(
20
+ use_flash_attn=config._attn_implementation == "flash_attention_2"
21
+ )
22
+ self.region_model = RegionModel()
23
+
24
+ if isinstance(config.text_config, dict):
25
+ phi_config = PhiConfig(
26
+ **config.text_config, attn_implementation=config._attn_implementation
27
+ )
28
+ else:
29
+ phi_config = config.text_config
30
+ self.text_model = PhiForCausalLM(phi_config)
31
+
32
+ @property
33
+ def device(self):
34
+ return self.text_model.device
35
+
36
+ def encode_image(self, image):
37
+ with torch.no_grad():
38
+ return self.vision_encoder(image)
39
+
40
+ def input_embeds(self, prompt, image_embeds, tokenizer):
41
+ def _tokenize(txt):
42
+ return tokenizer(
43
+ txt, return_tensors="pt", add_special_tokens=False
44
+ ).input_ids.to(self.device)
45
+
46
+ text_emb = self.text_model.get_input_embeddings()
47
+
48
+ # Add BOS token
49
+ embeds = []
50
+ embeds.append(
51
+ text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
52
+ )
53
+
54
+ if "<image>" not in prompt:
55
+ embeds.append(text_emb(_tokenize(prompt)))
56
+ else:
57
+ assert prompt.count("<image>") == 1
58
+ before, after = prompt.split("<image>")
59
+ if len(before) > 0:
60
+ embeds.append(text_emb(_tokenize(before)))
61
+ embeds.append(image_embeds.to(self.device))
62
+ if len(after) > 0:
63
+ embeds.append(text_emb(_tokenize(after)))
64
+
65
+ return torch.cat(embeds, dim=1)
66
+
67
+ def get_input_embeddings(self):
68
+ return self.text_model.get_input_embeddings()
69
+
70
+ def generate(
71
+ self,
72
+ image_embeds,
73
+ prompt,
74
+ tokenizer,
75
+ max_new_tokens=128,
76
+ **kwargs,
77
+ ):
78
+ generate_config = {
79
+ "eos_token_id": tokenizer.eos_token_id,
80
+ "bos_token_id": tokenizer.bos_token_id,
81
+ "pad_token_id": tokenizer.bos_token_id,
82
+ "max_new_tokens": max_new_tokens,
83
+ **kwargs,
84
+ }
85
+
86
+ with torch.no_grad():
87
+ inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
88
+ attention_mask = torch.ones(
89
+ (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device
90
+ )
91
+ output_ids = self.text_model.generate(
92
+ inputs_embeds=inputs_embeds,
93
+ attention_mask=attention_mask,
94
+ **generate_config,
95
+ )
96
+
97
+ return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
98
+
99
+ # Note: Not ready for use yet, intended for September release.
100
+ def caption(
101
+ self,
102
+ images: List[Image.Image],
103
+ tokenizer,
104
+ length: Optional[Literal["short"]] = None,
105
+ **kwargs,
106
+ ):
107
+ image_embeds = self.encode_image(images)
108
+
109
+ templated_prompts = [
110
+ f"<image>\n\n{'Short caption' if length == 'short' else 'Caption'}:"
111
+ for _ in images
112
+ ]
113
+ inputs_embeds = torch.stack(
114
+ [
115
+ self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
116
+ for prompt, image_embed in zip(templated_prompts, image_embeds)
117
+ ]
118
+ )
119
+ attention_mask = torch.ones(
120
+ (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device
121
+ )
122
+
123
+ generate_config = {
124
+ "eos_token_id": tokenizer.eos_token_id,
125
+ "bos_token_id": tokenizer.bos_token_id,
126
+ "pad_token_id": tokenizer.bos_token_id,
127
+ "repetition_penalty": 1.2,
128
+ "max_new_tokens": 512,
129
+ **kwargs,
130
+ }
131
+
132
+ with torch.no_grad():
133
+ output_ids = self.text_model.generate(
134
+ inputs_embeds=inputs_embeds,
135
+ attention_mask=attention_mask,
136
+ **generate_config,
137
+ )
138
+
139
+ return [
140
+ x.strip()
141
+ for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
142
+ ]
143
+
144
+ def answer_question(
145
+ self,
146
+ image_embeds,
147
+ question,
148
+ tokenizer,
149
+ chat_history="",
150
+ result_queue=None,
151
+ max_new_tokens=256,
152
+ **kwargs,
153
+ ):
154
+ prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
155
+ answer = self.generate(
156
+ image_embeds,
157
+ prompt,
158
+ tokenizer=tokenizer,
159
+ max_new_tokens=max_new_tokens,
160
+ **kwargs,
161
+ )[0]
162
+ cleaned_answer = answer.strip()
163
+
164
+ # Use the result_queue to pass the result if it is provided
165
+ if result_queue:
166
+ result_queue.put(cleaned_answer)
167
+ else:
168
+ return cleaned_answer
169
+
170
+ def batch_answer(
171
+ self,
172
+ images,
173
+ prompts,
174
+ tokenizer,
175
+ **kwargs,
176
+ ):
177
+ image_embeds = self.encode_image(images)
178
+
179
+ templated_prompts = [
180
+ f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
181
+ ]
182
+ prompt_embs = [
183
+ self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
184
+ for prompt, image_embed in zip(templated_prompts, image_embeds)
185
+ ]
186
+
187
+ bos_emb = prompt_embs[0][0]
188
+ max_len = max([p.shape[0] for p in prompt_embs])
189
+
190
+ inputs_embeds = torch.cat(
191
+ [
192
+ torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
193
+ for p in prompt_embs
194
+ ],
195
+ dim=0,
196
+ )
197
+ attention_mask = torch.cat(
198
+ [
199
+ torch.cat(
200
+ [
201
+ torch.zeros(
202
+ 1,
203
+ max_len - p.shape[0],
204
+ device=self.device,
205
+ dtype=torch.long,
206
+ ),
207
+ torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
208
+ ],
209
+ dim=1,
210
+ )
211
+ for p in prompt_embs
212
+ ],
213
+ dim=0,
214
+ )
215
+
216
+ generate_config = {
217
+ "eos_token_id": tokenizer.eos_token_id,
218
+ "bos_token_id": tokenizer.bos_token_id,
219
+ "pad_token_id": tokenizer.bos_token_id,
220
+ "max_new_tokens": 512,
221
+ **kwargs,
222
+ }
223
+
224
+ with torch.no_grad():
225
+ output_ids = self.text_model.generate(
226
+ inputs_embeds=inputs_embeds,
227
+ attention_mask=attention_mask,
228
+ **generate_config,
229
+ )
230
+
231
+ return [
232
+ x.strip()
233
+ for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
234
+ ]
235
+
236
+ def detect(
237
+ self,
238
+ image: Image.Image,
239
+ query: str,
240
+ tokenizer,
241
+ max_objects=50,
242
+ ):
243
+ prompt = f"<image>\n\nDetect: {query}\n\n"
244
+ image_embeds = self.encode_image(image)
245
+ inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
246
+ generate_config = {
247
+ "eos_token_id": tokenizer.eos_token_id,
248
+ "bos_token_id": tokenizer.bos_token_id,
249
+ "pad_token_id": tokenizer.bos_token_id,
250
+ "max_new_tokens": 1,
251
+ }
252
+
253
+ past_key_values = None
254
+ generated_boxes = []
255
+
256
+ with torch.no_grad():
257
+ while len(generated_boxes) < max_objects:
258
+ # x coordinate
259
+ attention_mask = torch.ones(
260
+ (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device
261
+ )
262
+ output = self.text_model.generate(
263
+ inputs_embeds=inputs_embeds,
264
+ past_key_values=past_key_values,
265
+ attention_mask=attention_mask,
266
+ return_dict_in_generate=True,
267
+ output_hidden_states=True,
268
+ **generate_config,
269
+ )
270
+ if output["sequences"][0][0].item() == tokenizer.eos_token_id:
271
+ break
272
+
273
+ x_coord_hidden = output["hidden_states"][0][-1][:, -1, :]
274
+ x_coord_logits = self.region_model.decode_coordinate(x_coord_hidden)
275
+ x_coord_decoded = (
276
+ torch.argmax(x_coord_logits, dim=-1).to(torch.float32) / 1024
277
+ ).to(torch.float16)
278
+ x_coord_encoded = self.region_model.encode_coordinate(
279
+ x_coord_decoded
280
+ ).unsqueeze(0)
281
+ inputs_embeds = torch.cat(
282
+ [inputs_embeds, x_coord_encoded.unsqueeze(0)], dim=1
283
+ )
284
+ past_key_values = output["past_key_values"]
285
+
286
+ # y coordinate
287
+ attention_mask = torch.ones(
288
+ (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device
289
+ )
290
+ output = self.text_model.generate(
291
+ inputs_embeds=inputs_embeds,
292
+ past_key_values=past_key_values,
293
+ attention_mask=attention_mask,
294
+ return_dict_in_generate=True,
295
+ output_hidden_states=True,
296
+ **generate_config,
297
+ )
298
+ y_coord_hidden = output["hidden_states"][0][-1][:, -1, :]
299
+ y_coord_logits = self.region_model.decode_coordinate(y_coord_hidden)
300
+ y_coord_decoded = (
301
+ torch.argmax(y_coord_logits, dim=-1).to(torch.float32) / 1024
302
+ ).to(torch.float16)
303
+ y_coord_encoded = self.region_model.encode_coordinate(
304
+ y_coord_decoded
305
+ ).unsqueeze(0)
306
+ inputs_embeds = torch.cat(
307
+ [inputs_embeds, y_coord_encoded.unsqueeze(0)], dim=1
308
+ )
309
+ past_key_values = output["past_key_values"]
310
+
311
+ # size (h and w)
312
+ attention_mask = torch.ones(
313
+ (inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device
314
+ )
315
+ output = self.text_model.generate(
316
+ inputs_embeds=inputs_embeds,
317
+ past_key_values=past_key_values,
318
+ attention_mask=attention_mask,
319
+ return_dict_in_generate=True,
320
+ output_hidden_states=True,
321
+ **generate_config,
322
+ )
323
+ size_hidden = output["hidden_states"][0][-1][:, -1, :]
324
+ size_logits = self.region_model.decode_size(size_hidden)
325
+ size_decoded = (
326
+ torch.argmax(size_logits, dim=-1).to(torch.float32) / 1024
327
+ ).to(torch.float16)
328
+ size_encoded = self.region_model.encode_size(size_decoded)
329
+ inputs_embeds = torch.cat(
330
+ [inputs_embeds, size_encoded.unsqueeze(0)], dim=1
331
+ )
332
+ past_key_values = output["past_key_values"]
333
+
334
+ x_center = x_coord_decoded[0].item()
335
+ y_center = y_coord_decoded[0].item()
336
+ w_center = size_decoded[0][0].item()
337
+ h_center = size_decoded[0][1].item()
338
+ x_min = max(x_center - w_center / 2, 0)
339
+ y_min = max(y_center - h_center / 2, 0)
340
+ x_max = min(x_center + w_center / 2, 1)
341
+ y_max = min(y_center + h_center / 2, 1)
342
+
343
+ generated_boxes.append(
344
+ {
345
+ "x_min": x_min,
346
+ "y_min": y_min,
347
+ "x_max": x_max,
348
+ "y_max": y_max,
349
+ }
350
+ )
351
+
352
+ return generated_boxes
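A hedged sketch of how the `detect` loop above is meant to be driven; `model` and `tokenizer` are assumed to be an already loaded Moondream instance and its tokenizer, and the image path and query are placeholders.

from PIL import Image

image = Image.open("street.jpg")  # placeholder path
boxes = model.detect(image, "traffic light", tokenizer, max_objects=10)
for box in boxes:
    # Coordinates are normalized to [0, 1]; multiply by the image size to get pixels.
    print(box["x_min"], box["y_min"], box["x_max"], box["y_max"])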
moondream/hf/region_model.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .fourier_features import FourierFeatures
5
+
6
+
7
+ class MLP(nn.Module):
8
+
9
+ def __init__(
10
+ self,
11
+ in_features: int,
12
+ hidden_features: int = None,
13
+ out_features: int = None,
14
+ ) -> None:
15
+ super().__init__()
16
+ out_features = out_features or in_features
17
+ hidden_features = hidden_features or in_features * 4
18
+ self.fc1 = nn.Linear(in_features, hidden_features)
19
+ self.act = nn.GELU(approximate="tanh")
20
+ self.fc2 = nn.Linear(hidden_features, out_features)
21
+
22
+ torch.nn.init.kaiming_normal_(
23
+ self.fc1.weight, mode="fan_in", nonlinearity="relu"
24
+ )
25
+ torch.nn.init.kaiming_normal_(
26
+ self.fc2.weight, mode="fan_in", nonlinearity="relu"
27
+ )
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ x = self.fc1(x)
31
+ x = self.act(x)
32
+ x = self.fc2(x)
33
+ return x
34
+
35
+
36
+ class RegionModel(nn.Module):
37
+ def __init__(self):
38
+ super().__init__()
39
+
40
+ self.coordinate_features = FourierFeatures(1, 256)
41
+ self.coordinate_encoder = nn.Linear(256, 2048)
42
+ self.size_features = FourierFeatures(2, 512)
43
+ self.size_encoder = nn.Linear(512, 2048)
44
+
45
+ self.coordinate_decoder = MLP(2048, 8192, 1024)
46
+ self.size_decoder = MLP(2048, 8192, 2048)
47
+
48
+ def encode_coordinate(self, coordinate):
49
+ return self.coordinate_encoder(self.coordinate_features(coordinate))
50
+
51
+ def encode_size(self, size):
52
+ return self.size_encoder(self.size_features(size))
53
+
54
+ def decode_coordinate(self, logit):
55
+ return self.coordinate_decoder(logit)
56
+
57
+ def decode_size(self, logit):
58
+ o = self.size_decoder(logit)
59
+ return o.view(-1, 2, 1024)
60
+
61
+ def encode(self, position, size):
62
+ c = self.encode_coordinate(position.view(2, 1)).view(2, 2048)
63
+ return torch.stack([c[0], c[1], self.encode_size(size)], dim=0)
64
+
65
+ def decode(self, position_logits, size_logits):
66
+ return (
67
+ self.decode_coordinate(position_logits),
68
+ self.decode_size(size_logits),
69
+ )
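A minimal shape sketch for the decoder heads above; the Fourier-feature encoders come from fourier_features.py, which is not part of this section, so the sketch sticks to the MLP decoders.

import torch

region = RegionModel()
hidden = torch.randn(1, 2048)                    # a text-model hidden state
coord_logits = region.decode_coordinate(hidden)  # -> (1, 1024): one bin per 1/1024 step
size_logits = region.decode_size(hidden)         # -> (1, 2, 1024): width and height bins
print(coord_logits.shape, size_logits.shape)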
moondream/hf/util.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+
3
+ LATEST_REVISION = "2024-08-26"
4
+
5
+
6
+ def detect_device():
7
+ """
8
+ Detects the appropriate device to run on and returns the device and dtype.
9
+ """
10
+ if torch.cuda.is_available():
11
+ return torch.device("cuda"), torch.float16
12
+ elif torch.backends.mps.is_available():
13
+ return torch.device("mps"), torch.float16
14
+ else:
15
+ return torch.device("cpu"), torch.float32
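A hypothetical way to combine this helper with the pinned revision above when loading the model; the call pattern is an illustration, not taken from this file.

from transformers import AutoModelForCausalLM

device, dtype = detect_device()
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",     # checkpoint id used elsewhere in this repo
    trust_remote_code=True,
    revision=LATEST_REVISION,  # pin the remote code to the revision above
    torch_dtype=dtype,
).to(device)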
moondream/hf/vision_encoder.py ADDED
@@ -0,0 +1,325 @@
1
+ from typing import Union
2
+
3
+ import PIL
4
+ import PIL.Image
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch import nn
9
+ from torchvision.transforms.v2 import (
10
+ Compose,
11
+ InterpolationMode,
12
+ Normalize,
13
+ Resize,
14
+ ToDtype,
15
+ ToImage,
16
+ )
17
+ from transformers.utils import is_flash_attn_2_available
18
+
19
+ try:
20
+ if is_flash_attn_2_available():
21
+ from flash_attn.modules.mha import FlashSelfAttention
22
+ else:
23
+ FlashSelfAttention = None
24
+ except ImportError:
25
+ FlashSelfAttention = None
26
+
27
+
28
+ class Attention(nn.Module):
29
+
30
+ def __init__(self, dim, num_heads=16, use_flash_attn=False):
31
+ super().__init__()
32
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
33
+
34
+ self.num_heads = num_heads
35
+ self.head_dim = dim // num_heads
36
+
37
+ self.qkv = nn.Linear(dim, dim * 3)
38
+ self.proj = nn.Linear(dim, dim)
39
+
40
+ if use_flash_attn and FlashSelfAttention is not None:
41
+ self.flash_attn = FlashSelfAttention()
42
+ else:
43
+ self.flash_attn = None
44
+
45
+ torch.nn.init.kaiming_normal_(
46
+ self.qkv.weight, mode="fan_in", nonlinearity="relu"
47
+ )
48
+ torch.nn.init.kaiming_normal_(
49
+ self.proj.weight, mode="fan_in", nonlinearity="relu"
50
+ )
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ if self.flash_attn is not None:
54
+ qkv = self.qkv(x)
55
+ qkv = rearrange(
56
+ qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
57
+ )
58
+ attn_output = self.flash_attn(qkv)
59
+ output = rearrange(attn_output, "... h d -> ... (h d)")
60
+ output = self.proj(output)
61
+ return output
62
+ else:
63
+ B, N, C = x.shape
64
+ qkv = (
65
+ self.qkv(x)
66
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
67
+ .permute(2, 0, 3, 1, 4)
68
+ )
69
+ q, k, v = qkv.unbind(0)
70
+
71
+ x = F.scaled_dot_product_attention(q, k, v)
72
+
73
+ x = x.transpose(1, 2).reshape(B, N, C)
74
+ x = self.proj(x)
75
+ return x
76
+
77
+
78
+ class VitBlock(nn.Module):
79
+
80
+ def __init__(self, embed_dim, use_flash_attn=False):
81
+ super().__init__()
82
+ self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
83
+ self.mlp = MLP(embed_dim, 4304)
84
+ self.norm1 = nn.LayerNorm(embed_dim)
85
+ self.norm2 = nn.LayerNorm(embed_dim)
86
+
87
+ def forward(self, x):
88
+ x = x + self.attn(self.norm1(x))
89
+ x = x + self.mlp(self.norm2(x))
90
+ return x
91
+
92
+
93
+ class VisionTransformer(nn.Module):
94
+
95
+ def __init__(self, use_flash_attn=False):
96
+ super().__init__()
97
+
98
+ embed_len = 729
99
+ embed_dim = 1152
100
+
101
+ self.patch_embed = LinearPatchEmbedding()
102
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
103
+ self.blocks = nn.Sequential(
104
+ *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
105
+ )
106
+ self.norm = nn.LayerNorm(embed_dim)
107
+
108
+ def forward(self, x):
109
+ x = self.patch_embed(x)
110
+ x = x + self.pos_embed
111
+ for block in self.blocks:
112
+ x = block(x)
113
+ return self.norm(x)
114
+
115
+
116
+ class EncoderWrapper(nn.Module):
117
+
118
+ def __init__(self, use_flash_attn=False):
119
+ super().__init__()
120
+ self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})
121
+
122
+ def forward(self, x):
123
+ return self.model["visual"](x)
124
+
125
+
126
+ class LinearPatchEmbedding(nn.Module):
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+ self.linear = nn.Linear(588, 1152)
131
+
132
+ def forward(self, x):
133
+ b, c, hp1, wp2 = x.shape
134
+ p1, p2 = 14, 14
135
+ h, w = hp1 // p1, wp2 // p2
136
+ x = x.reshape(b, c, h, p1, w, p2)
137
+ x = x.permute(0, 2, 4, 1, 3, 5)
138
+ x = x.reshape(b, h * w, c * p1 * p2)
139
+
140
+ return self.linear(x)
141
+
142
+
143
+ class MLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ in_features: int,
147
+ hidden_features: int = None,
148
+ out_features: int = None,
149
+ ) -> None:
150
+ super().__init__()
151
+ out_features = out_features or in_features
152
+ hidden_features = hidden_features or in_features
153
+ self.fc1 = nn.Linear(in_features, hidden_features)
154
+ self.act = nn.GELU(approximate="tanh")
155
+ self.fc2 = nn.Linear(hidden_features, out_features)
156
+
157
+ torch.nn.init.kaiming_normal_(
158
+ self.fc1.weight, mode="fan_in", nonlinearity="relu"
159
+ )
160
+ torch.nn.init.kaiming_normal_(
161
+ self.fc2.weight, mode="fan_in", nonlinearity="relu"
162
+ )
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = self.fc1(x)
166
+ x = self.act(x)
167
+ x = self.fc2(x)
168
+ return x
169
+
170
+
171
+ class VisionProjection(nn.Module):
172
+ def __init__(self):
173
+ super().__init__()
174
+
175
+ image_embedding_dim = 1152
176
+ model_dim = 2048
177
+ hidden_dim = model_dim * 4
178
+
179
+ self.mlp = MLP(image_embedding_dim * 2, hidden_dim, model_dim)
180
+
181
+ @property
182
+ def device(self):
183
+ return self.mlp.fc1.weight.device
184
+
185
+ def forward(self, x):
186
+ return self.mlp(x)
187
+
188
+
189
+ def create_patches(image, patch_size=(378, 378)):
190
+ assert image.dim() == 3, "Image must be in CHW format"
191
+
192
+ _, height, width = image.shape # Channels, Height, Width
193
+ patch_height, patch_width = patch_size
194
+
195
+ if height == patch_height and width == patch_width:
196
+ return []
197
+
198
+ # Iterate over the image and create patches
199
+ patches = []
200
+ for i in range(0, height, patch_height):
201
+ row_patches = []
202
+ for j in range(0, width, patch_width):
203
+ patch = image[:, i : i + patch_height, j : j + patch_width]
204
+ row_patches.append(patch)
205
+ patches.append(torch.stack(row_patches))
206
+ return patches
207
+
208
+
209
+ class VisionEncoder(nn.Module):
210
+
211
+ def __init__(self, use_flash_attn=False):
212
+ super().__init__()
213
+
214
+ self.encoder = EncoderWrapper(use_flash_attn)
215
+ self.projection = VisionProjection()
216
+ self.supported_sizes = [(378, 378), (378, 756), (756, 378), (756, 756)]
217
+
218
+ @property
219
+ def device(self):
220
+ return self.projection.mlp.fc1.weight.device
221
+
222
+ @property
223
+ def dtype(self):
224
+ return self.projection.mlp.fc1.weight.dtype
225
+
226
+ def preprocess(self, image: PIL.Image.Image):
227
+ width, height = image.size
228
+ max_dim = max(width, height)
229
+ if max_dim < 512:
230
+ im_size = (378, 378)
231
+ else:
232
+ aspect_ratio = width / height
233
+ im_size = min(
234
+ self.supported_sizes,
235
+ key=lambda size: (
236
+ abs((size[1] / size[0]) - aspect_ratio),
237
+ abs(size[0] - width) + abs(size[1] - height),
238
+ ),
239
+ )
240
+
241
+ return Compose(
242
+ [
243
+ Resize(size=im_size, interpolation=InterpolationMode.BICUBIC),
244
+ ToImage(),
245
+ ToDtype(torch.float16, scale=True),
246
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
247
+ ]
248
+ )(image)
249
+
250
+ def forward(
251
+ self, images: Union[PIL.Image.Image, list[PIL.Image.Image], torch.Tensor]
252
+ ) -> torch.Tensor:
253
+ im_list = None
254
+ if isinstance(images, torch.Tensor):
255
+ # Input must have dimensions (B, C, H, W)
256
+ assert (
257
+ len(images.shape) == 4
258
+ ), "Tensor input must have dimensions (B, C, H, W)"
259
+ im_list = list(images)
260
+ elif isinstance(images, PIL.Image.Image):
261
+ im_list = [images]
262
+ elif isinstance(images, list):
263
+ im_list = images
264
+ else:
265
+ raise ValueError(
266
+ "Input must be a PIL image, list of PIL images, or a tensor"
267
+ )
268
+
269
+ # Preprocess unless the images are already tensors (indicating that
270
+ # they have already been preprocessed)
271
+ if not isinstance(im_list[0], torch.Tensor):
272
+ im_list = [self.preprocess(im.convert("RGB")) for im in im_list]
273
+
274
+ patches = [create_patches(im) for im in im_list]
275
+ flat_patches = [patch for image_patches in patches for patch in image_patches]
276
+
277
+ # Images may be variable size, and need to be resized to a common size after
278
+ # creating patches.
279
+ resized_images = [
280
+ F.interpolate(im.unsqueeze(0), size=(378, 378), mode="bilinear")
281
+ for im in im_list
282
+ ]
283
+
284
+ combined_images = torch.cat([*resized_images, *flat_patches], dim=0)
285
+ combined_images = combined_images.to(self.device, dtype=self.dtype)
286
+
287
+ combined_features = self.encoder(combined_images)
288
+
289
+ full_img_features = combined_features[: len(im_list)]
290
+ patch_features = (
291
+ combined_features[len(im_list) :].transpose(1, 2).view(-1, 1152, 27, 27)
292
+ )
293
+
294
+ # Reshape patch features back to their original structure
295
+ reshaped_patch_features = []
296
+ patch_idx = 0
297
+ for i, patch_set in enumerate(patches):
298
+ if len(patch_set) == 0:
299
+ reshaped_patch_features.append(
300
+ full_img_features[i].transpose(0, 1).view(1152, 27, 27)
301
+ )
302
+ else:
303
+ sample_features = []
304
+ for row_patches in patch_set:
305
+ row_len = len(row_patches)
306
+ row_features = patch_features[
307
+ patch_idx : patch_idx + row_len
308
+ ] # row_len, T, C
309
+ row_features = torch.cat(
310
+ list(row_features), dim=2
311
+ ) # T, C * row_len
312
+ patch_idx += row_len
313
+ sample_features.append(row_features)
314
+ sample_features = torch.cat(sample_features, dim=1)
315
+ sample_features = F.adaptive_avg_pool2d(
316
+ sample_features, output_size=(27, 27)
317
+ )
318
+ reshaped_patch_features.append(sample_features)
319
+ reshaped_patch_features = (
320
+ torch.stack(reshaped_patch_features).view(-1, 1152, 729).transpose(1, 2)
321
+ )
322
+
323
+ final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
324
+
325
+ return self.projection(final_features)
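The token geometry implied by the constants above can be checked with a dummy tensor; a small sketch (shapes only, no pretrained weights involved).

import torch

# Each 378x378 crop is split into 14x14 pixel patches: 378 / 14 = 27 per side, giving
# 27 * 27 = 729 tokens of width 3 * 14 * 14 = 588 before the linear lift to 1152 channels.
dummy = torch.randn(1, 3, 378, 378)
tokens = LinearPatchEmbedding()(dummy)
print(tokens.shape)  # torch.Size([1, 729, 1152])

# VisionEncoder.forward then concatenates full-image features (1152) with pooled patch
# features (another 1152), and VisionProjection maps the 2304 channels down to 2048.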
moondream/torch/layers.py ADDED
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+
9
+ def gelu_approx(x):
10
+ return F.gelu(x, approximate="tanh")
11
+
12
+
13
+ @dataclass
14
+ class LinearWeights:
15
+ weight: torch.Tensor
16
+ bias: torch.Tensor
17
+
18
+
19
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
20
+ return F.linear(x, w.weight, w.bias)
21
+
22
+
23
+ @dataclass
24
+ class LayerNormWeights:
25
+ weight: torch.Tensor
26
+ bias: torch.Tensor
27
+
28
+
29
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
30
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
31
+
32
+
33
+ @dataclass
34
+ class MLPWeights:
35
+ fc1: LinearWeights
36
+ fc2: LinearWeights
37
+ act: Literal["gelu_approx"] = "gelu_approx"
38
+
39
+
40
+ def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
41
+ x = linear(x, w.fc1)
42
+ if w.act == "gelu_approx":
43
+ x = gelu_approx(x)
44
+ else:
45
+ raise NotImplementedError(f"Activation function {w.act} not implemented.")
46
+ x = linear(x, w.fc2)
47
+ return x
48
+
49
+
50
+ @dataclass
51
+ class AttentionWeights:
52
+ qkv: LinearWeights
53
+ proj: LinearWeights
54
+ n_heads: int
55
+
56
+
57
+ def attn(x: torch.Tensor, w: AttentionWeights) -> torch.Tensor:
58
+ bsz, q_len, d_model = x.shape
59
+ n_heads, head_dim = w.n_heads, d_model // w.n_heads
60
+
61
+ q, k, v = [
62
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
63
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
64
+ ]
65
+ out = F.scaled_dot_product_attention(q, k, v)
66
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
67
+ out = linear(out, w.proj)
68
+ return out
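A small sketch of wiring these stateless helpers to ordinary nn.Linear modules; the layer sizes are illustrative.

import torch
from torch import nn

def as_linear_weights(m: nn.Linear) -> LinearWeights:
    # Adapts an nn.Linear module to the plain weight dataclass used above.
    return LinearWeights(weight=m.weight, bias=m.bias)

fc1, fc2 = nn.Linear(64, 256), nn.Linear(256, 64)
w = MLPWeights(fc1=as_linear_weights(fc1), fc2=as_linear_weights(fc2))
x = torch.randn(2, 10, 64)
print(mlp(x, w).shape)  # torch.Size([2, 10, 64])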
moondream/torch/rope.py ADDED
@@ -0,0 +1,46 @@
1
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+
7
+
8
+ def precompute_freqs_cis(
9
+ dim: int,
10
+ end: int,
11
+ theta: float = 10000.0,
12
+ use_scaled: bool = False,
13
+ dtype: torch.dtype = torch.float32,
14
+ ) -> torch.Tensor:
15
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
16
+ t = torch.arange(end, dtype=dtype).unsqueeze(1)
17
+ freqs = t * freqs.unsqueeze(0)
18
+ freqs = torch.exp(1j * freqs)
19
+ return torch.stack([freqs.real, freqs.imag], dim=-1)
20
+
21
+
22
+ def apply_rotary_emb(
23
+ x: torch.Tensor,
24
+ freqs_cis: torch.Tensor,
25
+ position_ids: torch.Tensor,
26
+ interleave: bool = False,
27
+ ) -> torch.Tensor:
28
+ rot_dim = freqs_cis.shape[-2] * 2
29
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
30
+
31
+ if interleave:
32
+ xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
33
+ xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
34
+ else:
35
+ d_q = x_rot.shape[-1] // 2
36
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
37
+
38
+ freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
39
+ freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
40
+
41
+ # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
42
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
43
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
44
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
45
+
46
+ return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
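A quick shape sketch of the two helpers above; a head dimension of 64 with a 32-channel rotary slice matches how sample.py and text.py in this commit call them, and the sequence length of 5 is arbitrary.

import torch

freqs_cis = precompute_freqs_cis(32, 2048)  # -> (2048, 16, 2): cos/sin pairs per position
q = torch.randn(1, 32, 5, 64)               # (batch, heads, seq, head_dim)
position_ids = torch.arange(5)
q_rot = apply_rotary_emb(q, freqs_cis, position_ids)
print(q_rot.shape)  # (1, 32, 5, 64): only the first 32 channels are rotated, shape unchanged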
moondream/torch/sample.py ADDED
@@ -0,0 +1,99 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import AutoTokenizer
8
+
9
+ from .rope import precompute_freqs_cis
10
+ from .text import lm_head, text_decoder, text_encoder
11
+ from .vision import encode_image
12
+ from .weights import load_from_pt, load_from_safetensors
13
+
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--image", "-i", type=str, required=True)
17
+ parser.add_argument("--prompt", "-p", type=str, required=True)
18
+ parser.add_argument("--model", "-m", type=str, required=True)
19
+ parser.add_argument("--config", "-c", type=str, default="{}")
20
+ parser.add_argument("--max-tokens", "-t", type=int, default=200)
21
+ parser.add_argument("--sampler", "-s", type=str, default="greedy")
22
+ args = parser.parse_args()
23
+
24
+ if torch.cuda.is_available():
25
+ torch.set_default_device("cuda")
26
+ elif torch.backends.mps.is_available():
27
+ torch.set_default_device("mps")
28
+
29
+ # Load config.
30
+ config = json.loads(args.config)
31
+ text_n_heads = config.get("text_n_heads", 32)
32
+
33
+ # Load model.
34
+ model_path = args.model
35
+ if not os.path.exists(model_path):
36
+ raise FileNotFoundError(f"Model not found at {model_path}")
37
+ if model_path.endswith(".pt"):
38
+ model = load_from_pt(model_path, **config)
39
+ elif model_path.endswith(".safetensors"):
40
+ model = load_from_safetensors(model_path, **config)
41
+ else:
42
+ raise ValueError(f"Invalid model format: {model_path}")
43
+
44
+ # Encode image.
45
+ image_path = args.image
46
+ if not os.path.exists(image_path):
47
+ raise FileNotFoundError(f"Image not found at {image_path}")
48
+ image = Image.open(image_path)
49
+ image = image.resize((378, 378))
50
+ image_tensor = encode_image(image, model.vision)
51
+
52
+ # Encode text, and create inputs_embeds.
53
+ tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
54
+ prompt = f"\n\nQuestion: {args.prompt}\n\nAnswer:"
55
+ input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
56
+ input_ids = torch.cat([torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1)
57
+ inputs_embeds = text_encoder(input_ids, model.text)
58
+ inputs_embeds = torch.cat(
59
+ [
60
+ inputs_embeds[:, 0:1, :],
61
+ image_tensor.unsqueeze(0),
62
+ inputs_embeds[:, 1:, :],
63
+ ],
64
+ dim=1,
65
+ )
66
+
67
+ kv_cache = torch.empty(24, 2, 1, text_n_heads, 2048, 64, dtype=torch.float16)
68
+ freqs_cis = precompute_freqs_cis(32, 2048)
69
+ pos = 0
70
+
71
+ for _ in range(args.max_tokens):
72
+ with torch.no_grad():
73
+ hidden, kv_cache_update = text_decoder(
74
+ inputs_embeds, model.text, kv_cache[:, :, :, :, :pos, :], freqs_cis
75
+ )
76
+ logits = lm_head(hidden, model.text)
77
+ kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
78
+ kv_cache_update
79
+ )
80
+ pos += kv_cache_update.size(-2)
81
+
82
+ if args.sampler == "multinomial":
83
+ next_token = torch.multinomial(
84
+ torch.softmax(logits, dim=-1), num_samples=1
85
+ ).squeeze(0)
86
+ elif args.sampler == "greedy":
87
+ next_token = torch.argmax(logits, dim=-1)
88
+ else:
89
+ raise ValueError(f"Invalid sampler: {args.sampler}")
90
+
91
+ if next_token == tokenizer.eos_token_id:
92
+ print()
93
+ break
94
+
95
+ input_ids = next_token.unsqueeze(0)
96
+ inputs_embeds = text_encoder(input_ids, model.text)
97
+
98
+ output_text = tokenizer.batch_decode(input_ids)[0]
99
+ print(output_text, end="", flush=True)
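The script above is meant to be run from the command line; an illustrative invocation follows, where the file names are placeholders and the module path assumes the package layout of this commit. The --config flag takes a JSON string, e.g. '{"text_n_heads": 32}'.

python -m moondream.torch.sample \
    --model moondream.safetensors \
    --image demo.jpg \
    --prompt "What is in this image?" \
    --max-tokens 100 \
    --sampler greedy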
moondream/torch/text.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ from .layers import layer_norm, linear, mlp
5
+ from .rope import apply_rotary_emb, precompute_freqs_cis
6
+ from .weights import AttentionWeights, TextModel, load_from_safetensors
7
+
8
+
9
+ def text_encoder(input_ids: torch.Tensor, w: TextModel):
10
+ return F.embedding(input_ids, w.wte)
11
+
12
+
13
+ def attn_mask(pos, seq_len):
14
+ """
15
+ Create an attention mask that aligns with the bottom right of the
16
+ attention matrix. For example, if q_len = 2 and kv_len = 5, we want the
17
+ following:
18
+
19
+ 1 1 1 1 0
20
+ 1 1 1 1 1
21
+
22
+ and not this, which is what we get by default if we just set is_causal.
23
+
24
+ 1 0 0 0 0
25
+ 1 1 0 0 0
26
+ """
27
+ mask = torch.ones(seq_len, pos + seq_len, dtype=torch.bool)
28
+ mask[:, pos:] = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
29
+ mask = mask.unsqueeze(0).unsqueeze(0) # Add batch and head dimensions
30
+ return mask
31
+
32
+
33
+ def attn(
34
+ x: torch.Tensor,
35
+ w: AttentionWeights,
36
+ freqs_cis: torch.Tensor,
37
+ layer_kv_cache: torch.Tensor,
38
+ ):
39
+ bsz, q_len, d_model = x.shape
40
+ pos = 0 if layer_kv_cache is None else layer_kv_cache.shape[3]
41
+ n_heads, head_dim = w.n_heads, d_model // w.n_heads
42
+
43
+ q, k, v = [
44
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
45
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
46
+ ]
47
+
48
+ position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
49
+ q = apply_rotary_emb(q, freqs_cis, position_ids)
50
+ k = apply_rotary_emb(k, freqs_cis, position_ids)
51
+
52
+ k_, v_ = k, v
53
+ if layer_kv_cache is not None:
54
+ k = torch.cat([layer_kv_cache[0], k], dim=2)
55
+ v = torch.cat([layer_kv_cache[1], v], dim=2)
56
+
57
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask(pos, q_len)).to(
58
+ # This type conversion isn't needed when running in PyTorch directly, but the
59
+ # ONNX export runs attention in float32 because the attention mask is cast to
60
+ # float32.
61
+ x.dtype
62
+ )
63
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
64
+ out = linear(out, w.proj)
65
+ return out, torch.stack([k_, v_])
66
+
67
+
68
+ def text_decoder(
69
+ inputs_embeds: torch.Tensor,
70
+ w: TextModel,
71
+ kv_cache: torch.Tensor,
72
+ freqs_cis: torch.Tensor,
73
+ ):
74
+ hidden_BTC = inputs_embeds
75
+ new_kv_cache = [torch.empty(0)] * len(w.blocks)
76
+
77
+ for i, block in enumerate(w.blocks):
78
+ l_in = layer_norm(hidden_BTC, block.ln)
79
+ l_attn, new_kv_cache[i] = attn(l_in, block.attn, freqs_cis, kv_cache[i])
80
+ l_mlp = mlp(l_in, block.mlp)
81
+ hidden_BTC = hidden_BTC + l_attn + l_mlp
82
+
83
+ return hidden_BTC, torch.stack(new_kv_cache)
84
+
85
+
86
+ def lm_head(hidden_BTC: torch.Tensor, w: TextModel):
87
+ hidden_BC = hidden_BTC[:, -1, :]
88
+ hidden_BC = layer_norm(hidden_BC, w.post_ln)
89
+ logits = linear(hidden_BC, w.lm_head)
90
+ return logits
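A quick check of the mask described in the attn_mask docstring above, using the same example (2 new query tokens against 3 cached plus 2 new key/value tokens).

mask = attn_mask(pos=3, seq_len=2)
print(mask.shape)        # torch.Size([1, 1, 2, 5]): batch and head dims are added
print(mask[0, 0].int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)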
moondream/torch/vision.py ADDED
@@ -0,0 +1,104 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from PIL import Image
6
+ from torch.nn import functional as F
7
+ from torchvision.transforms.v2 import InterpolationMode
8
+ from torchvision.transforms.v2.functional import normalize
9
+ from torchvision.transforms.v2.functional import resize as tv_resize
10
+ from torchvision.transforms.v2.functional import to_dtype, to_image
11
+
12
+ from .layers import attn, layer_norm, linear, mlp
13
+ from .weights import VisionModel, load_from_safetensors
14
+
15
+
16
+ def im_resize(
17
+ image: Image.Image,
18
+ size: List[int],
19
+ interpolation: InterpolationMode = InterpolationMode.BICUBIC,
20
+ ) -> Image.Image:
21
+ """
22
+ The `resize` function from torchvision has a misleading type signature:
23
+ it accepts both PIL images and torch tensors, but the signature
+ only allows tensors (hence the `type: ignore` below).
25
+ """
26
+ return tv_resize(
27
+ image, # type: ignore
28
+ size,
29
+ interpolation,
30
+ )
+
+
+ def create_patches(
+     image: Image.Image, image_patch_size=378
+ ) -> Tuple[List[Image.Image], Tuple[int, int]]:
+     """
+     Split the given image into a variable number of patches depending upon its
+     resolution.
+     """
+     # Start off with the global patch.
+     patches = [im_resize(image, [image_patch_size, image_patch_size])]
+
+     # Find the closest resolution template.
+     res_templates = [(1, 2), (2, 1), (2, 2)]
+     im_width, im_height = image.size
+     max_dim = max(im_width, im_height)
+     if max_dim < image_patch_size * 1.4:
+         # If the image is already small, we just do a single patch that is a
+         # duplicate of the global patch. This creates a small amount of
+         # redundant computation now, but it is simpler and future-proofs us
+         # if/when we condition the vision encoder on the patch type.
+         res_template = (1, 1)
+         patches.append(patches[0])
+     else:
+         aspect_ratio = im_width / im_height
+         res_template = min(
+             res_templates, key=lambda size: abs((size[1] / size[0]) - aspect_ratio)
+         )
+         # TODO: Actually implement patching... just going to put in the global
+         # patch for now to make progress on other aspects.
+         patches.append(patches[0])
+
+     return patches, res_template
+
+
+ def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor:
+     patches, res_template = create_patches(image.convert("RGB"))
+     patches = torch.stack(
+         [
+             normalize(
+                 to_dtype(to_image(patch), torch.float16, scale=True),
+                 mean=[0.5, 0.5, 0.5],
+                 std=[0.5, 0.5, 0.5],
+             )
+             for patch in patches
+         ]
+     )
+
+     outputs = vision_encoder(patches, weights)
+
+     # TODO: Merge sub-image patch outputs properly... for now we'll just assume
+     # that the global patch is repeated.
+     assert outputs.shape[0] == 2, "Expected single image patch."
+     outputs = torch.cat([outputs[0], outputs[1]], dim=-1)
+
+     return mlp(outputs, weights.proj_mlp)
+
+
+ def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel):
+     x = rearrange(
+         input_BCHW,
+         "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
+         p1=w.patch_size,
+         p2=w.patch_size,
+     )  # B3HW -> B(HxW)(3xP1xP2), aka BTC
+
+     x = linear(x, w.patch_emb)
+     x = x + w.pos_emb
+     for block in w.blocks:
+         x = x + attn(layer_norm(x, block.ln1), block.attn)
+         x = x + mlp(layer_norm(x, block.ln2), block.mlp)
+     x = layer_norm(x, w.post_ln)
+
+     return x
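
For intuition about the `rearrange` patchification in `vision_encoder`, here is a stand-alone shape check. It is illustrative only: the 378x378 crop size comes from `create_patches` above, while the patch size of 14 is a hypothetical value (in the real model it is derived from the patch-embedding weight shape in `weights.py`).

import torch
from einops import rearrange

# Two 378x378 RGB crops (global patch + duplicated sub-patch), patch size 14.
crops = torch.zeros(2, 3, 378, 378)
tokens = rearrange(crops, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
print(tokens.shape)  # torch.Size([2, 729, 588]): 27*27 patch tokens of 3*14*14 values each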
moondream/torch/weights.py ADDED
@@ -0,0 +1,216 @@
+ import math
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import List, Callable
+
+ import safetensors
+ import torch
+
+ from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
+
+
+ @dataclass
+ class VisionBlock:
+     ln1: LayerNormWeights
+     attn: AttentionWeights
+     ln2: LayerNormWeights
+     mlp: MLPWeights
+
+
+ @dataclass
+ class VisionModel:
+     patch_size: int
+     patch_emb: LinearWeights
+     pos_emb: torch.Tensor
+     blocks: List[VisionBlock]
+     post_ln: LayerNormWeights
+     proj_mlp: MLPWeights
+
+
+ @dataclass
+ class TextBlock:
+     ln: LayerNormWeights
+     attn: AttentionWeights
+     mlp: MLPWeights
+
+
+ @dataclass
+ class TextModel:
+     wte: torch.Tensor
+     blocks: List[TextBlock]
+     post_ln: LayerNormWeights
+     lm_head: LinearWeights
+
+
+ @dataclass
+ class MoondreamModel:
+     vision: VisionModel
+     text: TextModel
+
+
+ @contextmanager
+ def safetensors_open(safetensors_file: str):
+     """
+     Simplify interfacing with safetensors files. Eliminates the need to ignore
+     type errors when using the `safe_open` function.
+     """
+     with safetensors.safe_open(
+         safetensors_file, framework="pt"
+     ) as st:  # pyright: ignore
+
+         def get_tensor(name: str) -> torch.Tensor:
+             return st.get_tensor(name)
+
+         yield get_tensor
+
+
+ def load_model(
+     get_tensor: Callable[[str], torch.Tensor],
+     vision_blocks: int = 27,
+     text_blocks: int = 24,
+     vision_n_heads: int = 16,
+     text_n_heads: int = 32,
+ ) -> MoondreamModel:
+     ## Vision encoder
+     prefix = "vision_encoder.encoder.model.visual.patch_embed.linear"
+     patch_emb = LinearWeights(
+         weight=get_tensor(f"{prefix}.weight"), bias=get_tensor(f"{prefix}.bias")
+     )
+     patch_size = int(math.sqrt(patch_emb.weight.shape[1] // 3))
+     pos_emb = get_tensor("vision_encoder.encoder.model.visual.pos_embed")
+     post_ln = LayerNormWeights(
+         weight=get_tensor("vision_encoder.encoder.model.visual.norm.weight"),
+         bias=get_tensor("vision_encoder.encoder.model.visual.norm.bias"),
+     )
+     blocks = []
+     for i in range(vision_blocks):
+         prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
+         blocks.append(
+             VisionBlock(
+                 ln1=LayerNormWeights(
+                     weight=get_tensor(f"{prefix}.norm1.weight"),
+                     bias=get_tensor(f"{prefix}.norm1.bias"),
+                 ),
+                 attn=AttentionWeights(
+                     qkv=LinearWeights(
+                         weight=get_tensor(f"{prefix}.attn.qkv.weight"),
+                         bias=get_tensor(f"{prefix}.attn.qkv.bias"),
+                     ),
+                     proj=LinearWeights(
+                         weight=get_tensor(f"{prefix}.attn.proj.weight"),
+                         bias=get_tensor(f"{prefix}.attn.proj.bias"),
+                     ),
+                     n_heads=vision_n_heads,
+                 ),
+                 ln2=LayerNormWeights(
+                     weight=get_tensor(f"{prefix}.norm2.weight"),
+                     bias=get_tensor(f"{prefix}.norm2.bias"),
+                 ),
+                 mlp=MLPWeights(
+                     fc1=LinearWeights(
+                         weight=get_tensor(f"{prefix}.mlp.fc1.weight"),
+                         bias=get_tensor(f"{prefix}.mlp.fc1.bias"),
+                     ),
+                     fc2=LinearWeights(
+                         weight=get_tensor(f"{prefix}.mlp.fc2.weight"),
+                         bias=get_tensor(f"{prefix}.mlp.fc2.bias"),
+                     ),
+                 ),
+             )
+         )
+     proj_mlp = MLPWeights(
+         fc1=LinearWeights(
+             weight=get_tensor("vision_encoder.projection.mlp.fc1.weight"),
+             bias=get_tensor("vision_encoder.projection.mlp.fc1.bias"),
+         ),
+         fc2=LinearWeights(
+             weight=get_tensor("vision_encoder.projection.mlp.fc2.weight"),
+             bias=get_tensor("vision_encoder.projection.mlp.fc2.bias"),
+         ),
+         act="gelu_approx",
+     )
+     vision = VisionModel(
+         patch_size=patch_size,
+         patch_emb=patch_emb,
+         pos_emb=pos_emb,
+         blocks=blocks,
+         post_ln=post_ln,
+         proj_mlp=proj_mlp,
+     )
+
+     ## Text decoder model
+     wte = get_tensor("text_model.transformer.embd.wte.weight")
+     post_ln = LayerNormWeights(
+         weight=get_tensor("text_model.lm_head.ln.weight"),
+         bias=get_tensor("text_model.lm_head.ln.bias"),
+     )
+     lm_head = LinearWeights(
+         weight=get_tensor("text_model.lm_head.linear.weight"),
+         bias=get_tensor("text_model.lm_head.linear.bias"),
+     )
+     blocks = []
+     for i in range(text_blocks):
+         prefix = f"text_model.transformer.h.{i}"
+         blocks.append(
+             TextBlock(
+                 ln=LayerNormWeights(
+                     weight=get_tensor(f"{prefix}.ln.weight"),
+                     bias=get_tensor(f"{prefix}.ln.bias"),
+                 ),
+                 attn=AttentionWeights(
+                     qkv=LinearWeights(
+                         weight=get_tensor(f"{prefix}.mixer.Wqkv.weight"),
+                         bias=get_tensor(f"{prefix}.mixer.Wqkv.bias"),
+                     ),
+                     proj=LinearWeights(
+                         weight=get_tensor(f"{prefix}.mixer.out_proj.weight"),
+                         bias=get_tensor(f"{prefix}.mixer.out_proj.bias"),
+                     ),
+                     n_heads=text_n_heads,
+                 ),
+                 mlp=MLPWeights(
+                     fc1=LinearWeights(
+                         weight=get_tensor(f"{prefix}.mlp.fc1.weight"),
+                         bias=get_tensor(f"{prefix}.mlp.fc1.bias"),
+                     ),
+                     fc2=LinearWeights(
+                         weight=get_tensor(f"{prefix}.mlp.fc2.weight"),
+                         bias=get_tensor(f"{prefix}.mlp.fc2.bias"),
+                     ),
+                     act="gelu_approx",
+                 ),
+             )
+         )
+     text = TextModel(wte=wte, blocks=blocks, post_ln=post_ln, lm_head=lm_head)
+
+     return MoondreamModel(vision=vision, text=text)
+
+
+ def load_from_safetensors(
+     safetensors_file: str,
+     vision_blocks: int = 27,
+     text_blocks: int = 24,
+     **kwargs,
+ ) -> MoondreamModel:
+     with safetensors_open(safetensors_file) as get_tensor:
+         return load_model(get_tensor, vision_blocks, text_blocks, **kwargs)
+
+
+ def load_from_pt(
+     pt_file: str,
+     vision_blocks: int = 27,
+     text_blocks: int = 24,
+     **kwargs,
+ ) -> MoondreamModel:
+     device = str(torch.empty(0).device)
+     tensors = torch.load(pt_file, map_location=device, weights_only=True)
+     tensors = {
+         k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
+         for k, v in tensors.items()
+     }
+     return load_model(lambda x: tensors[x], vision_blocks, text_blocks, **kwargs)
+
+
+ if __name__ == "__main__":
+     weights = load_from_safetensors("model.safetensors")
+     print(weights)
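
A minimal usage sketch tying the new files together, assuming these modules are importable as the `moondream.torch` package added in this commit, that a local `model.safetensors` export with the tensor names expected by `load_model` is available, and a hypothetical `example.jpg`:

from PIL import Image

from moondream.torch.vision import encode_image
from moondream.torch.weights import load_from_safetensors

weights = load_from_safetensors("model.safetensors")
image = Image.open("example.jpg")

# Image embeddings ready to be spliced into the text decoder's input sequence.
image_embeds = encode_image(image, weights.vision)
print(image_embeds.shape)

Note that everything here is float16, so in practice the tensors may need to be moved to an accelerator (or upcast) before running the encoder.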