speed committed on
Commit 8bb5da7 · verified · 1 Parent(s): e8adde1

Update README.md

Files changed (1)
  1. README.md +91 -9
README.md CHANGED
@@ -11,19 +11,101 @@ This is a simple MLP trained on the MNIST dataset.
 
  Its primary use is to be a very simple reference model to test quantization.
 
- ## Inputs preprocessing
+ ## How to use
 
  The MNIST images must be normalized and flattened as follows:
 
  ```
- from torchvision import datasets, transforms
-
- transform=transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize((0.1307,), (0.3081,)),
-     transforms.Lambda(lambda x: torch.flatten(x)),
- ])
- test_set = datasets.MNIST('../data', train=False, download=True,
-                           transform=transform)
+ from datasets import load_dataset
+ from torchvision import transforms
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+
+ def build_multi_modal_prompt(
+     prompt: str,
+     image: torch.Tensor,
+     tokenizer: AutoTokenizer,
+     model: AutoModelForCausalLM,
+     vision_model: AutoModel,
+ ) -> torch.Tensor:
+     # Embed the text on either side of the <image> placeholder.
+     parts = prompt.split("<image>")
+     prefix = tokenizer(parts[0])
+     suffix = tokenizer(parts[1])
+     prefix_embedding = model.get_input_embeddings()(torch.tensor(prefix["input_ids"]))
+     suffix_embedding = model.get_input_embeddings()(torch.tensor(suffix["input_ids"]))
+     # Project the image into the language model's embedding space and splice it in.
+     image_embedding = vision_model(image).to(torch.bfloat16).to(model.device)
+     multi_modal_embedding = torch.cat(
+         [prefix_embedding, image_embedding, suffix_embedding], dim=0
+     )
+     return multi_modal_embedding
+
+
+ model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
+
+ # The MNIST MLP that maps a flattened digit image to the embedding space.
+ vision_model = AutoModel.from_pretrained(
+     "speed/llava-mnist", trust_remote_code=True
+ )
+
+ terminators = [
+     tokenizer.eos_token_id,
+     tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+ ]
+
+ # Llama 3.1 chat template, with the <image> placeholder inside the user turn.
+ system_prompt = (
+     "<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|>"
+ )
+ user_prompt = "<|start_header_id|>user<|end_header_id|>"
+ question = "<image>What digit is this?"
+ assistant_prompt = "<|start_header_id|>assistant<|end_header_id|>"
+
+ prompt = system_prompt + user_prompt + question + assistant_prompt
+
+ ds = load_dataset("ylecun/mnist", split="test")
+
+
+ def transform_image(examples):
+     # Normalize and flatten the MNIST images.
+     transform = transforms.Compose(
+         [
+             transforms.ToTensor(),
+             transforms.Normalize((0.1307,), (0.3081,)),
+             transforms.Lambda(lambda x: torch.flatten(x)),
+         ]
+     )
+     examples["pixel_values"] = [transform(image) for image in examples["image"]]
+     return examples
+
+
+ ds.set_transform(transform_image)
+
+ model.eval()
+ vision_model.eval()
+
+ example = ds[0]
+
+ input_embeds = build_multi_modal_prompt(
+     prompt, example["pixel_values"].unsqueeze(0), tokenizer, model, vision_model
+ ).unsqueeze(0)
+ response = model.generate(
+     inputs_embeds=input_embeds,
+     max_new_tokens=20,
+     eos_token_id=terminators,
+     do_sample=True,
+     temperature=0.6,
+     top_p=0.9,
+ )
+ response = response[0]
+ print("Label:", example["label"])
+ answer = tokenizer.decode(response, skip_special_tokens=True)
+ print("Answer:", answer)
  ```
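
Beyond the single test image above, the same helper can be looped over a few more examples to spot-check the answers. The sketch below assumes the script from the updated README has just been run in the same session; the sample count, the greedy decoding, and the substring check on the decoded answer are illustrative choices, not part of the committed code.

```
correct = 0
num_samples = 10  # illustrative sample size, not part of the committed README
for i in range(num_samples):
    example = ds[i]
    embeds = build_multi_modal_prompt(
        prompt, example["pixel_values"].unsqueeze(0), tokenizer, model, vision_model
    ).unsqueeze(0)
    output = model.generate(
        inputs_embeds=embeds,
        max_new_tokens=20,
        eos_token_id=terminators,
        do_sample=False,  # greedy decoding keeps the spot check repeatable
    )
    answer = tokenizer.decode(output[0], skip_special_tokens=True)
    # Count a hit if the decoded answer mentions the true digit anywhere.
    correct += str(example["label"]) in answer
print(f"Matched {correct}/{num_samples} labels")
```

Because the answer is free-form text, the substring check is only a rough proxy for accuracy.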
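
Since the model card describes the MLP as a reference model for quantization testing, one minimal starting point is dynamic int8 quantization of its Linear layers with PyTorch's built-in API. This is a sketch under assumptions, not part of the repository: it assumes the vision model runs on CPU in float32, and `quantized_vision_model` is a name introduced here purely for illustration.

```
import torch

# Swap the MLP's nn.Linear layers for dynamically quantized int8 versions (CPU only).
quantized_vision_model = torch.ao.quantization.quantize_dynamic(
    vision_model, {torch.nn.Linear}, dtype=torch.qint8
)
quantized_vision_model.eval()

# Reuse the helper from the README with the quantized MLP and compare its
# answers against the float32 baseline above.
embeds = build_multi_modal_prompt(
    prompt,
    example["pixel_values"].unsqueeze(0),
    tokenizer,
    model,
    quantized_vision_model,
).unsqueeze(0)
```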