JHong committed
Commit e2d1b00 · Parent: a45bc4e

Add application file

app.py CHANGED
@@ -399,12 +399,12 @@ def build_demo(embed_mode):
         gr.Examples(
             examples=[
                 [
-                    f"{cur_dir}/examples/extreme_ironing.jpg",
-                    "What is unusual about this image?",
+                    f"{cur_dir}/examples/CXR628_IM-2208-3001.png",
+                    "Is there any indication of an enlarged heart based on this image?",
                 ],
                 [
-                    f"{cur_dir}/examples/waterview.jpg",
-                    "What are the things I should be cautious about when I visit here?",
+                    f"{cur_dir}/examples/CXR22_IM-0810-1001.png",
+                    "Can you identify any signs of pulmonary fibrosis?",
                 ],
             ],
             inputs=[imagebox, textbox],
examples/CXR22_IM-0810-1001.png ADDED
examples/CXR628_IM-2208-3001.png ADDED
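Note: the two demo examples above now point at the chest X-ray images added under examples/. A minimal, hypothetical sketch of how gr.Examples ties an example row to the demo's image and text inputs (component names mirror the demo's imagebox/textbox; the file path is one of the images added in this commit and must exist on disk):

import gradio as gr

with gr.Blocks() as demo:
    imagebox = gr.Image(type="pil")
    textbox = gr.Textbox(label="Question")
    gr.Examples(
        examples=[
            ["examples/CXR628_IM-2208-3001.png",
             "Is there any indication of an enlarged heart based on this image?"],
        ],
        inputs=[imagebox, textbox],  # clicking the example fills both components
    )

demo.launch()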
llava/eval/eval_science_qa.py CHANGED
@@ -32,6 +32,7 @@ def get_pred_idx(prediction, choices, options):
     if prediction in options[:len(choices)]:
         return options.index(prediction)
     else:
+        return -1
         return random.choice(range(len(choices)))
 
 
@@ -55,16 +56,23 @@ if __name__ == "__main__":
 
     for prob_id, prob in split_problems.items():
         if prob_id not in predictions:
-            continue
-        pred = predictions[prob_id]
-        pred_text = pred['text']
-
-        pattern = re.compile(r'The answer is ([A-Z]).')
-        res = pattern.findall(pred_text)
-        if len(res) == 1:
-            answer = res[0]  # 'A', 'B', ...
+            pred = {'text': 'FAILED', 'prompt': 'Unknown'}
+            pred_text = 'FAILED'
         else:
-            answer = "FAILED"
+            pred = predictions[prob_id]
+            pred_text = pred['text']
+
+        if pred_text in args.options:
+            answer = pred_text
+        elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
+            answer = pred_text[0]
+        else:
+            pattern = re.compile(r'The answer is ([A-Z]).')
+            res = pattern.findall(pred_text)
+            if len(res) == 1:
+                answer = res[0]  # 'A', 'B', ...
+            else:
+                answer = "FAILED"
 
         pred_idx = get_pred_idx(answer, prob['choices'], args.options)
 
@@ -87,7 +95,14 @@ if __name__ == "__main__":
 
     correct = len(results['correct'])
     total = len(results['correct']) + len(results['incorrect'])
-    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
+
+    ###### IMG ######
+    multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
+    multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
+    multimodal_total = multimodal_correct + multimodal_incorrect
+    ###### IMG ######
+
+    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
 
     sqa_results['acc'] = correct / total * 100
     sqa_results['correct'] = correct
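Note: the rewritten block above accepts three answer formats instead of one: a bare option letter, a letter followed by ". " and a restatement, or the older "The answer is X." sentence. A self-contained sketch of that extraction logic (the option letters mirror the script's default --options; the example strings are illustrative):

import re

def extract_answer(pred_text, options=("A", "B", "C", "D", "E")):
    if pred_text in options:                                   # bare letter: "B"
        return pred_text
    if len(pred_text) >= 3 and pred_text[0] in options and pred_text[1:3] == ". ":
        return pred_text[0]                                    # "B. The heart is enlarged"
    res = re.findall(r'The answer is ([A-Z]).', pred_text)     # sentence form
    return res[0] if len(res) == 1 else "FAILED"

print(extract_answer("B. The heart is enlarged"))  # -> B
print(extract_answer("The answer is C."))          # -> C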
llava/eval/model_vqa.py CHANGED
@@ -66,7 +66,7 @@ def eval_model(args):
         output_ids = model.generate(
             input_ids,
             images=image_tensor.unsqueeze(0).half().cuda(),
-            do_sample=True,
+            do_sample=True if args.temperature > 0 else False,
             temperature=args.temperature,
             top_p=args.top_p,
             num_beams=args.num_beams,
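Note: gating do_sample on the temperature means a temperature of 0 now falls back to greedy decoding instead of sampling, since sampling with temperature 0 is ill-defined in transformers' generate(). A small sketch of the same guard in isolation (argument names and values are illustrative):

def generation_kwargs(temperature, top_p=None, num_beams=1, max_new_tokens=128):
    # Build the keyword arguments passed on to model.generate(...)
    kwargs = dict(num_beams=num_beams, max_new_tokens=max_new_tokens, use_cache=True)
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        kwargs.update(do_sample=False)  # greedy decoding
    return kwargs

print(generation_kwargs(0.0))  # do_sample=False -> deterministic output
print(generation_kwargs(0.7))  # do_sample=True, temperature=0.7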
llava/eval/model_vqa_loader.py CHANGED
@@ -104,7 +104,6 @@ def eval_model(args):
             top_p=args.top_p,
             num_beams=args.num_beams,
             max_new_tokens=128,
-            # max_length=64,
             use_cache=True)
 
         input_token_len = input_ids.shape[1]
@@ -124,7 +123,7 @@ def eval_model(args):
                                    "answer_id": ans_id,
                                    "model_id": model_name,
                                    "metadata": {}}) + "\n")
-        ans_file.flush()
+        # ans_file.flush()
     ans_file.close()
 
 if __name__ == "__main__":
llava/eval/model_vqa_science.py CHANGED
@@ -57,6 +57,10 @@ def eval_model(args):
         else:
             images = None
 
+        if args.single_pred_prompt:
+            qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+            cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
+
         conv = conv_templates[args.conv_mode].copy()
         conv.append_message(conv.roles[0], qs)
         conv.append_message(conv.roles[1], None)
@@ -72,8 +76,8 @@ def eval_model(args):
             output_ids = model.generate(
                 input_ids,
                 images=images,
-                do_sample=True,
-                temperature=0.2,
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
                 max_new_tokens=1024,
                 use_cache=True,
                 stopping_criteria=stopping_criteria,
@@ -98,8 +102,8 @@ def eval_model(args):
                 output_ids = model.generate(
                     input_ids,
                     images=images,
-                    do_sample=True,
-                    temperature=0.2,
+                    do_sample=True if args.temperature > 0 else False,
+                    temperature=args.temperature,
                     max_new_tokens=64,
                     use_cache=True,
                     stopping_criteria=[stopping_criteria])
@@ -135,7 +139,9 @@ if __name__ == "__main__":
     parser.add_argument("--conv-mode", type=str, default="llava_v0")
     parser.add_argument("--num-chunks", type=int, default=1)
     parser.add_argument("--chunk-idx", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=0.2)
     parser.add_argument("--answer-prompter", action="store_true")
+    parser.add_argument("--single-pred-prompt", action="store_true")
     args = parser.parse_args()
 
     eval_model(args)
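Note: --single-pred-prompt appends a fixed instruction so the model replies with a bare option letter, which lines up with the stricter answer parsing added to eval_science_qa.py above. An illustrative example of the resulting prompt (the question text is made up):

question = "Which chamber of the heart appears enlarged?\nA. Left atrium\nB. Right ventricle"
single_pred_prompt = True  # i.e. --single-pred-prompt was passed

if single_pred_prompt:
    question = question + "\n" + "Answer with the option's letter from the given choices directly."

print(question)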
llava/eval/summarize_gpt_review.py CHANGED
@@ -9,8 +9,10 @@ import argparse
 def parse_args():
     parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
     parser.add_argument('-d', '--dir', default=None)
-    parser.add_argument('-f', '--files', nargs='*', default=None)
-    parser.add_argument('-i', '--ignore', nargs='*', default=None)
+    parser.add_argument('-v', '--version', default=None)
+    parser.add_argument('-s', '--select', nargs='*', default=None)
+    parser.add_argument('-f', '--files', nargs='*', default=[])
+    parser.add_argument('-i', '--ignore', nargs='*', default=[])
     return parser.parse_args()
 
 
@@ -20,19 +22,27 @@ if __name__ == '__main__':
     if args.ignore is not None:
         args.ignore = [int(x) for x in args.ignore]
 
-    if args.files is not None and len(args.files) > 0:
+    if len(args.files) > 0:
         review_files = args.files
     else:
-        review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))]
+        review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
 
     for review_file in sorted(review_files):
         config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
+        if args.select is not None and any(x not in config for x in args.select):
+            continue
+        if '0613' in config:
+            version = '0613'
+        else:
+            version = '0314'
+        if args.version is not None and args.version != version:
+            continue
         scores = defaultdict(list)
         print(config)
         with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
             for review_str in f:
                 review = json.loads(review_str)
-                if args.ignore is not None and review['question_id'] in args.ignore:
+                if review['question_id'] in args.ignore:
                     continue
                 if 'category' in review:
                     scores[review['category']].append(review['tuple'])
@@ -46,5 +56,5 @@ if __name__ == '__main__':
         stats = np.asarray(v).mean(0).tolist()
         stats = [round(x, 3) for x in stats]
         # print(k, stats, round(stats[1]/stats[0]*100, 1))
-        print(k, round(stats[1]/stats[0]*100, 1))
+        print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
     print('=================================')
llava/mm_utils.py CHANGED
@@ -77,23 +77,26 @@ class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
         self.keywords = keywords
         self.keyword_ids = []
+        self.max_keyword_len = 0
         for keyword in keywords:
             cur_keyword_ids = tokenizer(keyword).input_ids
             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                 cur_keyword_ids = cur_keyword_ids[1:]
+            if len(cur_keyword_ids) > self.max_keyword_len:
+                self.max_keyword_len = len(cur_keyword_ids)
             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
         self.tokenizer = tokenizer
         self.start_len = input_ids.shape[1]
 
     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
-        offset = min(output_ids.shape[1] - self.start_len, 3)
+        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
         for keyword_id in self.keyword_ids:
-            if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
+            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                 return True
         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
         for keyword in self.keywords:
             if keyword in outputs:
                 return True
-        return False
+        return False
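Note: two genuine bugs are fixed here. Comparing two tensors with == returns an element-wise boolean tensor, so the old `if output_ids[0, -keyword_id.shape[0]:] == keyword_id:` raises "Boolean value of Tensor with more than one element is ambiguous" for multi-token keywords; .all() reduces the comparison to a single truth value. The hard-coded offset of 3 is also replaced by the longest keyword length so longer stop strings are decoded in full. A minimal illustration of the comparison fix:

import torch

tail = torch.tensor([13, 2, 29937])        # last generated tokens (illustrative ids)
keyword_id = torch.tensor([13, 2, 29937])  # tokenized stop string

print(tail == keyword_id)                # tensor([True, True, True]) -- ambiguous in an `if`
print(bool((tail == keyword_id).all()))  # True -- what the fixed code checks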
llava/model/builder.py CHANGED
@@ -23,9 +23,8 @@ from llava.model import *
 from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
     kwargs = {"device_map": device_map}
-    kwargs["offload_folder"] = "offload"
 
     if load_8bit:
         kwargs['load_in_8bit'] = True
@@ -138,9 +137,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
     vision_tower = model.get_vision_tower()
     if not vision_tower.is_loaded:
         vision_tower.load_model()
-
-
-    vision_tower.to(device=model.device, dtype=torch.float16)
+    vision_tower.to(device=device, dtype=torch.float16)
     image_processor = vision_tower.image_processor
 
     if hasattr(model.config, "max_sequence_length"):
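Note: load_pretrained_model gains an explicit device argument, the vision tower is moved to that device instead of model.device, and the forced offload_folder is dropped. A hedged usage sketch (the checkpoint path is a placeholder):

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "path/to/llava-checkpoint"  # placeholder
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path),
    load_8bit=False, load_4bit=False, device="cuda")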
llava/model/language_model/mpt/attention.py CHANGED
@@ -151,7 +151,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softma
 class MultiheadAttention(nn.Module):
     """Multi-head self attention.
 
-    Using torch or triton attention implemetation enables user to also use
+    Using torch or triton attention implementation enables user to also use
     additive bias.
     """
 
@@ -204,7 +204,7 @@ class MultiheadAttention(nn.Module):
 class MultiQueryAttention(nn.Module):
     """Multi-Query self attention.
 
-    Using torch or triton attention implemetation enables user to also use
+    Using torch or triton attention implementation enables user to also use
     additive bias.
     """
 
@@ -297,4 +297,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
     slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
     alibi_bias = alibi_bias * slopes
     return alibi_bias.to(dtype=dtype)
-ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
+ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
llava/model/llava_arch.py CHANGED
@@ -47,12 +47,19 @@ class LlavaMetaModel:
 
         self.config.mm_vision_tower = vision_tower
 
-        vision_tower = build_vision_tower(model_args)
+        if self.get_vision_tower() is None:
+            vision_tower = build_vision_tower(model_args)
 
-        if fsdp is not None and len(fsdp) > 0:
-            self.vision_tower = [vision_tower]
+            if fsdp is not None and len(fsdp) > 0:
+                self.vision_tower = [vision_tower]
+            else:
+                self.vision_tower = vision_tower
         else:
-            self.vision_tower = vision_tower
+            if fsdp is not None and len(fsdp) > 0:
+                vision_tower = self.vision_tower[0]
+            else:
+                vision_tower = self.vision_tower
+            vision_tower.load_model()
 
         self.config.use_mm_proj = True
         self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
@@ -60,7 +67,8 @@ class LlavaMetaModel:
         self.config.mm_vision_select_layer = mm_vision_select_layer
         self.config.mm_vision_select_feature = mm_vision_select_feature
 
-        self.mm_projector = build_vision_projector(self.config)
+        if getattr(self, 'mm_projector', None) is None:
+            self.mm_projector = build_vision_projector(self.config)
 
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
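Note: this makes the vision-module setup safe to call more than once: the vision tower and mm_projector are only built when they do not already exist, and an existing tower just reloads its weights. The guarded-construction pattern in isolation (class and attribute names here are illustrative, not the project's):

class LazyParts:
    def setup(self, build_tower, build_projector):
        if getattr(self, "tower", None) is None:
            self.tower = build_tower()          # first call: construct
        else:
            self.tower.load_model()             # later calls: reuse, just reload weights
        if getattr(self, "projector", None) is None:
            self.projector = build_projector()  # never rebuilt once present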
llava/serve/cli.py CHANGED
@@ -5,7 +5,7 @@ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_S
 from llava.conversation import conv_templates, SeparatorStyle
 from llava.model.builder import load_pretrained_model
 from llava.utils import disable_torch_init
-from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
 
 from PIL import Image
 
@@ -16,7 +16,7 @@ from transformers import TextStreamer
 
 
 def load_image(image_file):
-    if image_file.startswith('http') or image_file.startswith('https'):
+    if image_file.startswith('http://') or image_file.startswith('https://'):
         response = requests.get(image_file)
         image = Image.open(BytesIO(response.content)).convert('RGB')
     else:
@@ -29,7 +29,7 @@ def main(args):
     disable_torch_init()
 
     model_name = get_model_name_from_path(args.model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
 
     if 'llama-2' in model_name.lower():
         conv_mode = "llava_llama_2"
@@ -52,7 +52,12 @@ def main(args):
     roles = conv.roles
 
     image = load_image(args.image_file)
-    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
+    # Similar operation in model_worker.py
+    image_tensor = process_images([image], image_processor, args)
+    if type(image_tensor) is list:
+        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
+    else:
+        image_tensor = image_tensor.to(model.device, dtype=torch.float16)
 
     while True:
         try:
@@ -90,8 +95,8 @@ def main(args):
                 input_ids,
                 images=image_tensor,
                 do_sample=True,
-                temperature=0.2,
-                max_new_tokens=1024,
+                temperature=args.temperature,
+                max_new_tokens=args.max_new_tokens,
                 streamer=streamer,
                 use_cache=True,
                 stopping_criteria=[stopping_criteria])
@@ -108,12 +113,13 @@ if __name__ == "__main__":
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
     parser.add_argument("--model-base", type=str, default=None)
     parser.add_argument("--image-file", type=str, required=True)
-    parser.add_argument("--num-gpus", type=int, default=1)
+    parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--conv-mode", type=str, default=None)
     parser.add_argument("--temperature", type=float, default=0.2)
     parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--load-8bit", action="store_true")
     parser.add_argument("--load-4bit", action="store_true")
     parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
     args = parser.parse_args()
     main(args)
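Note: the CLI now preprocesses through process_images, which honors the new --image-aspect-ratio flag (default 'pad'), and keeps tensors on model.device instead of hard-coding .cuda(). A hedged sketch of the preprocessing call on its own (assumptions: llava is importable, the CLIP-L/14-336 processor is the one the checkpoint expects, and the image path is one of the files added in this commit):

import argparse
from PIL import Image
from transformers import CLIPImageProcessor
from llava.mm_utils import process_images

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
cfg = argparse.Namespace(image_aspect_ratio="pad")  # stands in for `args` in cli.py
image = Image.open("examples/CXR628_IM-2208-3001.png").convert("RGB")
image_tensor = process_images([image], image_processor, cfg)
print(image_tensor.shape)  # e.g. torch.Size([1, 3, 336, 336])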
llava/serve/model_worker.py CHANGED
@@ -45,7 +45,7 @@ class ModelWorker:
     def __init__(self, controller_addr, worker_addr,
                  worker_id, no_register,
                  model_path, model_base, model_name,
-                 load_8bit, load_4bit):
+                 load_8bit, load_4bit, device):
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id
@@ -60,9 +60,10 @@ class ModelWorker:
         else:
             self.model_name = model_name
 
+        self.device = device
         logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, model_base, self.model_name, load_8bit, load_4bit)
+            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
         self.is_multimodal = 'llava' in self.model_name.lower()
 
         if not no_register:
@@ -159,7 +160,7 @@ class ModelWorker:
         stop_str = params.get("stop", None)
         do_sample = True if temperature > 0.001 else False
 
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
@@ -258,6 +259,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
     parser.add_argument("--model-base", type=str, default=None)
     parser.add_argument("--model-name", type=str)
+    parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
     parser.add_argument("--limit-model-concurrency", type=int, default=5)
     parser.add_argument("--stream-interval", type=int, default=1)
@@ -278,5 +280,6 @@ if __name__ == "__main__":
                          args.model_base,
                          args.model_name,
                          args.load_8bit,
-                         args.load_4bit)
+                         args.load_4bit,
+                         args.device)
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
llava/train/train.py CHANGED
@@ -163,12 +163,14 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
 def find_all_linear_names(model):
     cls = torch.nn.Linear
     lora_module_names = set()
+    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if isinstance(module, cls):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
-
     if 'lm_head' in lora_module_names: # needed for 16-bit
         lora_module_names.remove('lm_head')
     return list(lora_module_names)
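Note: find_all_linear_names feeds the LoRA target-module list during fine-tuning, so skipping names containing mm_projector, vision_tower, or vision_resampler keeps LoRA adapters on the language model only. A self-contained sketch of the effect on a toy module (the helper is condensed from the patched version above; Toy is illustrative):

import torch.nn as nn

def find_all_linear_names(model):
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    names = set()
    for name, module in model.named_modules():
        if any(k in name for k in multimodal_keywords):
            continue  # never wrap multimodal components with LoRA
        if isinstance(module, nn.Linear):
            parts = name.split('.')
            names.add(parts[0] if len(parts) == 1 else parts[-1])
    names.discard('lm_head')  # needed for 16-bit
    return list(names)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.mm_projector = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))

print(find_all_linear_names(Toy()))  # ['q_proj'] -- projector layers are excluded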