x-lai committed
Commit 968fffb · Parent: 0146331

support 4bit and 8bit inference

Former-commit-id: 23930126323a0effb75929a5cc88c75c0d7bfbc2

README.md CHANGED
@@ -53,10 +53,15 @@ To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-
 ```
 CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0'
 ```
-To use `bfloat16` data type for inference:
+To use the `bf16` or `fp16` data type for inference:
 ```
 CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0' --precision='bf16'
 ```
+To use `8bit` or `4bit` quantization for inference:
+```
+CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0' --precision='fp16' --load_in_8bit
+CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0' --precision='fp16' --load_in_4bit
+```
 
 After that, input the text prompt and then the image path. For example,
 ```
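Side note on the `--precision` choice documented above: `bf16` generally needs an Ampere-class or newer GPU, while `fp16` (and the quantized paths) also run on older cards. A minimal sketch, not part of this commit, that picks a flag from what the local GPU reports (the helper name is hypothetical):

```python
import torch

def pick_precision() -> str:
    """Hypothetical helper: prefer bf16 where the GPU supports it, else fall back to fp16."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return "bf16"
    return "fp16"

# e.g. pass the result to chat.py via --precision
print(pick_precision())
```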
chat.py CHANGED
@@ -17,19 +17,22 @@ def parse_args(args):
     parser = argparse.ArgumentParser(description='LISA chat')
     parser.add_argument('--version', default='xinlai/LISA-13B-llama2-v0')
     parser.add_argument('--vis_save_path', default='./vis_output', type=str)
-    parser.add_argument('--precision', default='bf16', type=str, choices=['fp32', 'bf16'], help="precision for inference")
+    parser.add_argument('--precision', default='bf16', type=str, choices=['fp32', 'bf16', 'fp16'], help="precision for inference")
     parser.add_argument('--image-size', default=1024, type=int, help='image size')
     parser.add_argument('--model-max-length', default=512, type=int)
     parser.add_argument('--lora-r', default=-1, type=int)
     parser.add_argument('--vision-tower', default='openai/clip-vit-large-patch14', type=str)
     parser.add_argument('--local-rank', default=0, type=int, help='node rank')
+    parser.add_argument('--load_in_8bit', action='store_true', default=False)
+    parser.add_argument('--load_in_4bit', action='store_true', default=False)
     return parser.parse_args(args)
 
 
 def preprocess(x,
-               pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
-               pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
-               img_size=1024) -> torch.Tensor:
+               pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
+               pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
+               img_size=1024
+               ) -> torch.Tensor:
     """Normalize pixel values and pad to a square input."""
     # Normalize colors
     x = (x - pixel_mean) / pixel_std
@@ -65,6 +68,8 @@ def main(args):
         args.version,
         args.lora_r,
         args.precision,
+        load_in_8bit=args.load_in_8bit,
+        load_in_4bit=args.load_in_4bit,
     )
 
     weight = {}
@@ -76,6 +81,14 @@ def main(args):
 
     if args.precision == 'bf16':
         model = model.bfloat16().cuda()
+    elif args.precision == 'fp16':
+        import deepspeed
+        model_engine = deepspeed.init_inference(model=model,
+                                                dtype=torch.half,
+                                                replace_with_kernel_inject=True,
+                                                replace_method="auto",
+                                                )
+        model = model_engine.module
     else:
         model = model.float().cuda()
 
@@ -113,12 +126,16 @@ def main(args):
         original_size_list = [image.shape[:2]]
         if args.precision == 'bf16':
             images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().bfloat16()
+        elif args.precision == 'fp16':
+            images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().half()
         else:
             images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().float()
         images = transform.apply_image(image)
         resize_list = [images.shape[:2]]
         if args.precision == 'bf16':
             images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().bfloat16()
+        elif args.precision == 'fp16':
+            images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().half()
         else:
             images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().float()
 
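For reference, the `fp16` branch added to `chat.py` above wraps the model with DeepSpeed's inference engine. Below is a stand-alone sketch of that pattern on a toy module; it assumes `deepspeed` is installed, the `nn.Sequential` model is a placeholder, and the keyword arguments mirror the diff (they may differ across DeepSpeed releases):

```python
import torch
import torch.nn as nn
import deepspeed  # assumed installed: pip install deepspeed

# Placeholder model standing in for the LISA language model.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()

# Build a half-precision inference engine. Kernel injection only replaces layer
# types DeepSpeed recognizes; for this toy model it is effectively a no-op.
engine = deepspeed.init_inference(
    model=model,
    dtype=torch.half,
    replace_with_kernel_inject=True,
)
model = engine.module  # unwrap the underlying module, as chat.py does

with torch.no_grad():
    out = model(torch.randn(1, 1024, device="cuda", dtype=torch.half))
print(out.dtype)  # torch.float16
```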
model/LISA.py CHANGED
@@ -9,7 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import transformers
 
-from transformers import LlamaForCausalLM, CLIPVisionModel
+from transformers import LlamaForCausalLM, CLIPVisionModel, BitsAndBytesConfig
 from peft import (
     LoraConfig,
     get_peft_model,
@@ -49,6 +49,8 @@ class LISA(nn.Module):
         llm_version,
         lora_r,
         precision,
+        load_in_4bit=False,
+        load_in_8bit=False,
         lora_target_modules=['q_proj', 'v_proj'],
         lora_alpha=16,
         lora_dropout=0.05,
@@ -69,6 +71,20 @@ class LISA(nn.Module):
         num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
         if precision == "bf16":
             self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.bfloat16, cache_dir=None, low_cpu_mem_usage=True)
+        elif precision == "fp16":
+            if load_in_4bit:
+                self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_4bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto',
+                    quantization_config=BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_use_double_quant=True,
+                        bnb_4bit_quant_type='nf4'
+                    )
+                )
+            elif load_in_8bit:
+                self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_8bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto')
+            else:
+                self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.half, cache_dir=None, low_cpu_mem_usage=True)
         else:
             self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.float32, cache_dir=None, low_cpu_mem_usage=True)
 
@@ -85,6 +101,8 @@ class LISA(nn.Module):
         if vision_tower.device.type == 'meta':
             if precision == 'bf16':
                 vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).cuda(local_rank)
+            elif precision == 'fp16':
+                vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.half, low_cpu_mem_usage=True).cuda(local_rank)
             else:
                 vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float32, low_cpu_mem_usage=True).cuda(local_rank)
             self.lm.get_model().vision_tower[0] = vision_tower
@@ -92,6 +110,8 @@ class LISA(nn.Module):
 
         if precision == "bf16":
             vision_tower.to(device='cuda', dtype=torch.bfloat16)
+        elif precision == "fp16":
+            vision_tower.to(device='cuda', dtype=torch.half)
         else:
             vision_tower.to(device='cuda', dtype=torch.float32)
 
@@ -135,58 +155,59 @@ class LISA(nn.Module):
 
     def evaluate(self, images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=32, tokenizer=None):
 
-        outputs = self.lm.generate(images=images_clip, input_ids=input_ids, max_new_tokens=max_new_tokens, num_beams=1, output_hidden_states=True, return_dict_in_generate=True)
-        output_hidden_states = outputs.hidden_states[-1]
-        output_ids = outputs.sequences
-
-        seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx)
-
-        last_embedding = None
-        last_output_logit = None
-        hidden_states = []
-
-        assert len(self.text_hidden_fcs) == 1
-        hidden_states.append(self.text_hidden_fcs[0](output_hidden_states))
-
-        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
-        pred_embeddings = last_hidden_state[seg_token_mask]
-
-        seg_token_counts = seg_token_mask.int().sum(-1) #[bs, ]
-        seg_token_offset = seg_token_counts.cumsum(-1)
-        seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0)
-
-        pred_embeddings_ = []
-        for i in range(len(seg_token_offset)-1):
-            start_i, end_i = seg_token_offset[i], seg_token_offset[i+1]
-            pred_embeddings_.append(pred_embeddings[start_i: end_i])
-        pred_embeddings = pred_embeddings_
-
-        image_embeddings = self.get_visual_embs(images)
-
-        multimask_output = False
-        pred_masks = []
-        for i in range(len(pred_embeddings)):
-            sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
-                points=None,
-                boxes=None,
-                masks=None,
-                text_embeds=pred_embeddings[i].unsqueeze(1),
-            )
-
-            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
-            low_res_masks, iou_predictions = self.visual_model.mask_decoder(
-                image_embeddings=image_embeddings[i].unsqueeze(0),
-                image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
-                multimask_output=multimask_output,
-            )
-
-            pred_mask = self.visual_model.postprocess_masks(
-                low_res_masks,
-                input_size=resize_list[i],
-                original_size=original_size_list[i],
-            )
-            pred_masks.append(pred_mask[:, 0])
+        with torch.no_grad():
+            outputs = self.lm.generate(images=images_clip, input_ids=input_ids, max_new_tokens=max_new_tokens, num_beams=1, output_hidden_states=True, return_dict_in_generate=True)
+            output_hidden_states = outputs.hidden_states[-1]
+            output_ids = outputs.sequences
+
+            seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx)
+
+            last_embedding = None
+            last_output_logit = None
+            hidden_states = []
+
+            assert len(self.text_hidden_fcs) == 1
+            hidden_states.append(self.text_hidden_fcs[0](output_hidden_states))
+
+            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
+            pred_embeddings = last_hidden_state[seg_token_mask]
+
+            seg_token_counts = seg_token_mask.int().sum(-1) #[bs, ]
+            seg_token_offset = seg_token_counts.cumsum(-1)
+            seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0)
+
+            pred_embeddings_ = []
+            for i in range(len(seg_token_offset)-1):
+                start_i, end_i = seg_token_offset[i], seg_token_offset[i+1]
+                pred_embeddings_.append(pred_embeddings[start_i: end_i])
+            pred_embeddings = pred_embeddings_
+
+            image_embeddings = self.get_visual_embs(images)
+
+            multimask_output = False
+            pred_masks = []
+            for i in range(len(pred_embeddings)):
+                sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
+                    points=None,
+                    boxes=None,
+                    masks=None,
+                    text_embeds=pred_embeddings[i].unsqueeze(1),
+                )
+
+                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
+                low_res_masks, iou_predictions = self.visual_model.mask_decoder(
+                    image_embeddings=image_embeddings[i].unsqueeze(0),
+                    image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
+                    sparse_prompt_embeddings=sparse_embeddings,
+                    dense_prompt_embeddings=dense_embeddings,
+                    multimask_output=multimask_output,
+                )
+
+                pred_mask = self.visual_model.postprocess_masks(
+                    low_res_masks,
+                    input_size=resize_list[i],
+                    original_size=original_size_list[i],
+                )
+                pred_masks.append(pred_mask[:, 0])
 
         return output_ids, pred_masks
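The quantized branches added to `LISA.__init__` above go through `bitsandbytes` via `transformers`. A stand-alone sketch of the same 4-bit configuration on a generic causal LM follows; the checkpoint id is a placeholder, and `bitsandbytes` plus a CUDA GPU are assumed (for 8-bit, `BitsAndBytesConfig(load_in_8bit=True)` is the analogous setting):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-13b-hf"  # placeholder checkpoint, not taken from this repo

# Mirrors the 4-bit settings used in the diff above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matmuls run in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # NormalFloat4 weight format
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    low_cpu_mem_usage=True,
)
```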
model/llava/model/llava.py CHANGED
@@ -63,6 +63,8 @@ class LlavaLlamaModel(LlamaModel):
         vision_tower.requires_grad_(False)
         if precision == 'bf16':
             vision_tower = vision_tower.to(torch.bfloat16)
+        elif precision == 'fp16':
+            vision_tower = vision_tower.to(torch.half)
         else:
             vision_tower = vision_tower.to(torch.float32)
 
model/segment_anything/modeling/image_encoder.py CHANGED
@@ -114,8 +114,13 @@ class ImageEncoderViT(nn.Module):
         for blk in self.blocks:
             x = blk(x)
 
-        x = self.neck(x.permute(0, 3, 1, 2))
-
+        dtype = x.dtype
+        if dtype == torch.float16:  # prevent overflow
+            with torch.autocast(device_type='cuda', dtype=torch.float32):
+                x = self.neck(x.permute(0, 3, 1, 2))
+            x = x.to(dtype)
+        else:
+            x = self.neck(x.permute(0, 3, 1, 2))
         return x
 
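The `image_encoder.py` change guards the SAM neck against fp16 overflow by computing it in fp32 and casting the result back. The same idea, sketched without autocast on a single numerically sensitive op (the helper name and shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

def stable_layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run layer norm in fp32 when the input is fp16, then cast back."""
    dtype = x.dtype
    if dtype == torch.float16:  # prevent fp16 overflow, as in the diff above
        out = F.layer_norm(x.float(), x.shape[-1:], weight.float(), bias.float())
        return out.to(dtype)
    return F.layer_norm(x, x.shape[-1:], weight, bias)

# Usage: fp16 activations go through the fp32 path and come back as fp16.
x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
b = torch.zeros(8, dtype=torch.float16)
print(stable_layer_norm(x, w, b).dtype)  # torch.float16
```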