JHong committed
Commit e2d1b00 · Parent(s): a45bc4e

Add application file

Files changed:
- app.py +4 -4
- examples/CXR22_IM-0810-1001.png +0 -0
- examples/CXR628_IM-2208-3001.png +0 -0
- llava/eval/eval_science_qa.py +25 -10
- llava/eval/model_vqa.py +1 -1
- llava/eval/model_vqa_loader.py +1 -2
- llava/eval/model_vqa_science.py +10 -4
- llava/eval/summarize_gpt_review.py +16 -6
- llava/mm_utils.py +6 -3
- llava/model/builder.py +2 -5
- llava/model/language_model/mpt/attention.py +3 -3
- llava/model/llava_arch.py +13 -5
- llava/serve/cli.py +13 -7
- llava/serve/model_worker.py +7 -4
- llava/train/train.py +3 -1
app.py
CHANGED
@@ -399,12 +399,12 @@ def build_demo(embed_mode):
             gr.Examples(
                 examples=[
                     [
-                        f"{cur_dir}/examples/
-                        "
+                        f"{cur_dir}/examples/CXR628_IM-2208-3001.png",
+                        "Is there any indication of an enlarged heart based on this image?",
                     ],
                     [
-                        f"{cur_dir}/examples/
-                        "
+                        f"{cur_dir}/examples/CXR22_IM-0810-1001.png",
+                        "nCan you identify any signs of pulmonary fibrosis?",
                     ],
                 ],
                 inputs=[imagebox, textbox],
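Note on the hunk above: it only swaps the two (image, question) example pairs that gr.Examples pre-fills into the demo's image and text inputs. Below is a minimal, self-contained sketch of how such an examples block plugs into a Gradio app; the component names imagebox/textbox and the CXR file paths come from the diff, while everything else (the Blocks layout, the echo_question handler) is an assumed stand-in, not the repository's actual build_demo().

# Minimal sketch only -- not the repository's build_demo(). It illustrates how
# gr.Examples pre-fills an image component and a textbox with example pairs.
import os
import gradio as gr

cur_dir = os.path.dirname(os.path.abspath(__file__))

def echo_question(image, question):
    # Placeholder handler; the real demo runs the multimodal model here.
    return f"Received question: {question}"

with gr.Blocks() as demo:
    imagebox = gr.Image(type="pil")
    textbox = gr.Textbox(label="Question")
    output = gr.Textbox(label="Answer")
    gr.Examples(
        examples=[
            [f"{cur_dir}/examples/CXR628_IM-2208-3001.png",
             "Is there any indication of an enlarged heart based on this image?"],
            [f"{cur_dir}/examples/CXR22_IM-0810-1001.png",
             "Can you identify any signs of pulmonary fibrosis?"],
        ],
        inputs=[imagebox, textbox],
    )
    textbox.submit(echo_question, [imagebox, textbox], output)

if __name__ == "__main__":
    demo.launch()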
examples/CXR22_IM-0810-1001.png
ADDED
examples/CXR628_IM-2208-3001.png
ADDED
llava/eval/eval_science_qa.py
CHANGED
@@ -32,6 +32,7 @@ def get_pred_idx(prediction, choices, options):
     if prediction in options[:len(choices)]:
         return options.index(prediction)
     else:
+        return -1
         return random.choice(range(len(choices)))


@@ -55,16 +56,23 @@ if __name__ == "__main__":

     for prob_id, prob in split_problems.items():
         if prob_id not in predictions:
-
-
-        pred_text = pred['text']
-
-        pattern = re.compile(r'The answer is ([A-Z]).')
-        res = pattern.findall(pred_text)
-        if len(res) == 1:
-            answer = res[0]  # 'A', 'B', ...
+            pred = {'text': 'FAILED', 'prompt': 'Unknown'}
+            pred_text = 'FAILED'
         else:
-
+            pred = predictions[prob_id]
+            pred_text = pred['text']
+
+        if pred_text in args.options:
+            answer = pred_text
+        elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
+            answer = pred_text[0]
+        else:
+            pattern = re.compile(r'The answer is ([A-Z]).')
+            res = pattern.findall(pred_text)
+            if len(res) == 1:
+                answer = res[0]  # 'A', 'B', ...
+            else:
+                answer = "FAILED"

         pred_idx = get_pred_idx(answer, prob['choices'], args.options)

@@ -87,7 +95,14 @@ if __name__ == "__main__":

     correct = len(results['correct'])
     total = len(results['correct']) + len(results['incorrect'])
-
+
+    ###### IMG ######
+    multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
+    multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
+    multimodal_total = multimodal_correct + multimodal_incorrect
+    ###### IMG ######
+
+    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')

     sqa_results['acc'] = correct / total * 100
     sqa_results['correct'] = correct
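The second hunk above replaces a regex-only answer parser with a cascade: an exact option letter, then an "A. ..." style prefix, then the old "The answer is X." pattern, and finally a FAILED sentinel (which the new return -1 in get_pred_idx turns into a wrong-by-definition index instead of a random guess). A minimal sketch of that cascade as a standalone function; extract_answer is a hypothetical name, not part of the script.

import re

def extract_answer(pred_text, options=("A", "B", "C", "D", "E")):
    # Mirrors the parsing cascade added in eval_science_qa.py (sketch only).
    if pred_text in options:                       # bare letter, e.g. "B"
        return pred_text
    if len(pred_text) >= 3 and pred_text[0] in options and pred_text[1:3] == ". ":
        return pred_text[0]                        # "B. Because ..." style
    res = re.findall(r'The answer is ([A-Z]).', pred_text)
    if len(res) == 1:
        return res[0]                              # sentence style
    return "FAILED"                                # maps to index -1 via get_pred_idx

print(extract_answer("B"))                          # B
print(extract_answer("C. The lungs are clear."))    # C
print(extract_answer("The answer is A."))           # A
print(extract_answer("not sure"))                   # FAILED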
llava/eval/model_vqa.py
CHANGED
@@ -66,7 +66,7 @@ def eval_model(args):
         output_ids = model.generate(
             input_ids,
             images=image_tensor.unsqueeze(0).half().cuda(),
-            do_sample=True,
+            do_sample=True if args.temperature > 0 else False,
             temperature=args.temperature,
             top_p=args.top_p,
             num_beams=args.num_beams,
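The one-line change above is the recurring pattern in these eval scripts: sampling is only enabled when the temperature is positive, so --temperature 0 falls back to deterministic greedy decoding instead of sampling with a zero temperature (which Hugging Face's sampling warpers generally reject). A small sketch of that gating as a helper; the diff itself still passes temperature unconditionally, while this variant also drops the sampling knobs in greedy mode.

# Sketch of the decoding-mode gating used across these eval scripts.
def generation_kwargs(temperature, top_p=None, num_beams=1, max_new_tokens=128):
    do_sample = temperature > 0          # temperature == 0 -> greedy decoding
    kwargs = dict(do_sample=do_sample, num_beams=num_beams,
                  max_new_tokens=max_new_tokens, use_cache=True)
    if do_sample:
        kwargs.update(temperature=temperature, top_p=top_p)
    return kwargs

print(generation_kwargs(0.0))               # greedy: no temperature/top_p
print(generation_kwargs(0.7, top_p=0.9))    # sampling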
llava/eval/model_vqa_loader.py
CHANGED
@@ -104,7 +104,6 @@ def eval_model(args):
             top_p=args.top_p,
             num_beams=args.num_beams,
             max_new_tokens=128,
-            # max_length=64,
             use_cache=True)

         input_token_len = input_ids.shape[1]
@@ -124,7 +123,7 @@
                                    "answer_id": ans_id,
                                    "model_id": model_name,
                                    "metadata": {}}) + "\n")
-        ans_file.flush()
+        # ans_file.flush()
     ans_file.close()

 if __name__ == "__main__":
llava/eval/model_vqa_science.py
CHANGED
@@ -57,6 +57,10 @@ def eval_model(args):
         else:
             images = None

+        if args.single_pred_prompt:
+            qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+            cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
+
         conv = conv_templates[args.conv_mode].copy()
         conv.append_message(conv.roles[0], qs)
         conv.append_message(conv.roles[1], None)
@@ -72,8 +76,8 @@
             output_ids = model.generate(
                 input_ids,
                 images=images,
-                do_sample=True,
-                temperature=
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
                 max_new_tokens=1024,
                 use_cache=True,
                 stopping_criteria=stopping_criteria,
@@ -98,8 +102,8 @@
                 output_ids = model.generate(
                     input_ids,
                     images=images,
-                    do_sample=True,
-                    temperature=
+                    do_sample=True if args.temperature > 0 else False,
+                    temperature=args.temperature,
                     max_new_tokens=64,
                     use_cache=True,
                     stopping_criteria=[stopping_criteria])
@@ -135,7 +139,9 @@ if __name__ == "__main__":
     parser.add_argument("--conv-mode", type=str, default="llava_v0")
     parser.add_argument("--num-chunks", type=int, default=1)
     parser.add_argument("--chunk-idx", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=0.2)
     parser.add_argument("--answer-prompter", action="store_true")
+    parser.add_argument("--single-pred-prompt", action="store_true")
     args = parser.parse_args()

     eval_model(args)
llava/eval/summarize_gpt_review.py
CHANGED
@@ -9,8 +9,10 @@ import argparse
 def parse_args():
     parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
     parser.add_argument('-d', '--dir', default=None)
-    parser.add_argument('-
-    parser.add_argument('-
+    parser.add_argument('-v', '--version', default=None)
+    parser.add_argument('-s', '--select', nargs='*', default=None)
+    parser.add_argument('-f', '--files', nargs='*', default=[])
+    parser.add_argument('-i', '--ignore', nargs='*', default=[])
     return parser.parse_args()


@@ -20,19 +22,27 @@ if __name__ == '__main__':
     if args.ignore is not None:
         args.ignore = [int(x) for x in args.ignore]

-    if
+    if len(args.files) > 0:
         review_files = args.files
     else:
-        review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))]
+        review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]

     for review_file in sorted(review_files):
         config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
+        if args.select is not None and any(x not in config for x in args.select):
+            continue
+        if '0613' in config:
+            version = '0613'
+        else:
+            version = '0314'
+        if args.version is not None and args.version != version:
+            continue
         scores = defaultdict(list)
         print(config)
         with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
             for review_str in f:
                 review = json.loads(review_str)
-                if
+                if review['question_id'] in args.ignore:
                     continue
                 if 'category' in review:
                     scores[review['category']].append(review['tuple'])
@@ -46,5 +56,5 @@ if __name__ == '__main__':
         stats = np.asarray(v).mean(0).tolist()
         stats = [round(x, 3) for x in stats]
         # print(k, stats, round(stats[1]/stats[0]*100, 1))
-        print(k, round(stats[1]/stats[0]*100, 1))
+        print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
     print('=================================')
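For reference, the extended print above emits, per category, the relative score (second value over first, as a percentage) followed by both averages rescaled by 10. The sketch below reproduces that summary on toy data, assuming each review 'tuple' is a (reference score, model score) pair on a 1-10 scale as in LLaVA-style GPT-4 reviews; the sample numbers are made up.

import numpy as np
from collections import defaultdict

# Toy scores standing in for the parsed review tuples.
scores = defaultdict(list)
scores['conv'] = [(8.0, 7.0), (9.0, 8.5)]
scores['detail'] = [(7.0, 7.5)]

for k, v in sorted(scores.items()):
    stats = [round(x, 3) for x in np.asarray(v).mean(0).tolist()]
    # relative score (%), then both averages rescaled to a 0-100 range
    print(k, round(stats[1] / stats[0] * 100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))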
llava/mm_utils.py
CHANGED
@@ -77,23 +77,26 @@ class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
         self.keywords = keywords
         self.keyword_ids = []
+        self.max_keyword_len = 0
         for keyword in keywords:
             cur_keyword_ids = tokenizer(keyword).input_ids
             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                 cur_keyword_ids = cur_keyword_ids[1:]
+            if len(cur_keyword_ids) > self.max_keyword_len:
+                self.max_keyword_len = len(cur_keyword_ids)
             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
         self.tokenizer = tokenizer
         self.start_len = input_ids.shape[1]

     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
-        offset = min(output_ids.shape[1] - self.start_len,
+        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
         for keyword_id in self.keyword_ids:
-            if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
+            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                 return True
         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
         for keyword in self.keywords:
             if keyword in outputs:
                 return True
-        return False
+        return False
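Two things change in KeywordsStoppingCriteria: max_keyword_len bounds how much of the generated tail is re-decoded on every step, and the .all() call makes the tensor comparison a proper element-wise match (comparing a multi-token keyword tensor with == alone raises an ambiguous-truth-value error). A usage sketch follows; the checkpoint name is a placeholder, and StoppingCriteriaList / the stopping_criteria argument are standard Hugging Face generate parameters rather than something introduced by this commit.

# Usage sketch for KeywordsStoppingCriteria with Hugging Face generate().
# "some-causal-lm" is a placeholder checkpoint, not one shipped with this repo.
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from llava.mm_utils import KeywordsStoppingCriteria

tokenizer = AutoTokenizer.from_pretrained("some-causal-lm")
model = AutoModelForCausalLM.from_pretrained("some-causal-lm")

prompt = "USER: Describe the image. ASSISTANT:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Stop as soon as the conversation separator ("###" here) is generated.
stopping = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)
output_ids = model.generate(
    input_ids,
    max_new_tokens=64,
    stopping_criteria=StoppingCriteriaList([stopping]),
)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))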
llava/model/builder.py
CHANGED
@@ -23,9 +23,8 @@ from llava.model import *
 from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
     kwargs = {"device_map": device_map}
-    kwargs["offload_folder"] = "offload"

     if load_8bit:
         kwargs['load_in_8bit'] = True
@@ -138,9 +137,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
     vision_tower = model.get_vision_tower()
     if not vision_tower.is_loaded:
         vision_tower.load_model()
-
-
-    vision_tower.to(device=model.device, dtype=torch.float16)
+    vision_tower.to(device=device, dtype=torch.float16)
     image_processor = vision_tower.image_processor

     if hasattr(model.config, "max_sequence_length"):
llava/model/language_model/mpt/attention.py
CHANGED
@@ -151,7 +151,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softma
 class MultiheadAttention(nn.Module):
     """Multi-head self attention.

-    Using torch or triton attention
+    Using torch or triton attention implementation enables user to also use
     additive bias.
     """

@@ -204,7 +204,7 @@ class MultiheadAttention(nn.Module):
 class MultiQueryAttention(nn.Module):
     """Multi-Query self attention.

-    Using torch or triton attention
+    Using torch or triton attention implementation enables user to also use
     additive bias.
     """

@@ -297,4 +297,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
     slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
     alibi_bias = alibi_bias * slopes
     return alibi_bias.to(dtype=dtype)
-ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
+ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
llava/model/llava_arch.py
CHANGED
@@ -47,12 +47,19 @@ class LlavaMetaModel:

         self.config.mm_vision_tower = vision_tower

-
+        if self.get_vision_tower() is None:
+            vision_tower = build_vision_tower(model_args)

-
-
+            if fsdp is not None and len(fsdp) > 0:
+                self.vision_tower = [vision_tower]
+            else:
+                self.vision_tower = vision_tower
         else:
-
+            if fsdp is not None and len(fsdp) > 0:
+                vision_tower = self.vision_tower[0]
+            else:
+                vision_tower = self.vision_tower
+            vision_tower.load_model()

         self.config.use_mm_proj = True
         self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
@@ -60,7 +67,8 @@ class LlavaMetaModel:
         self.config.mm_vision_select_layer = mm_vision_select_layer
         self.config.mm_vision_select_feature = mm_vision_select_feature

-        self
+        if getattr(self, 'mm_projector', None) is None:
+            self.mm_projector = build_vision_projector(self.config)

         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
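One detail worth noting in the first hunk: under FSDP the vision tower is stored inside a plain Python list (self.vision_tower = [vision_tower]). Wrapping a module in a list keeps PyTorch from registering it as a submodule of the parent, so it stays out of the parent's parameters and, presumably, out of FSDP's wrapping of the rest of the model. A tiny demonstration of that registration difference (illustrative only, not code from this repo):

# Shows why self.vision_tower = [vision_tower] differs from a direct assignment:
# modules hidden inside a list are not registered as submodules.
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, wrap_in_list):
        super().__init__()
        tower = nn.Linear(4, 4)
        if wrap_in_list:
            self.vision_tower = [tower]   # hidden from named_modules()/parameters()
        else:
            self.vision_tower = tower     # registered like any submodule

print([n for n, _ in Wrapper(False).named_modules()])      # ['', 'vision_tower']
print([n for n, _ in Wrapper(True).named_modules()])       # ['']
print(sum(p.numel() for p in Wrapper(True).parameters()))  # 0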
llava/serve/cli.py
CHANGED
@@ -5,7 +5,7 @@ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_S
 from llava.conversation import conv_templates, SeparatorStyle
 from llava.model.builder import load_pretrained_model
 from llava.utils import disable_torch_init
-from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

 from PIL import Image

@@ -16,7 +16,7 @@ from transformers import TextStreamer


 def load_image(image_file):
-    if image_file.startswith('http') or image_file.startswith('https'):
+    if image_file.startswith('http://') or image_file.startswith('https://'):
         response = requests.get(image_file)
         image = Image.open(BytesIO(response.content)).convert('RGB')
     else:
@@ -29,7 +29,7 @@ def main(args):
     disable_torch_init()

     model_name = get_model_name_from_path(args.model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

     if 'llama-2' in model_name.lower():
         conv_mode = "llava_llama_2"
@@ -52,7 +52,12 @@ def main(args):
     roles = conv.roles

     image = load_image(args.image_file)
-
+    # Similar operation in model_worker.py
+    image_tensor = process_images([image], image_processor, args)
+    if type(image_tensor) is list:
+        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
+    else:
+        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

     while True:
         try:
@@ -90,8 +95,8 @@ def main(args):
             input_ids,
             images=image_tensor,
             do_sample=True,
-            temperature=
-            max_new_tokens=
+            temperature=args.temperature,
+            max_new_tokens=args.max_new_tokens,
             streamer=streamer,
             use_cache=True,
             stopping_criteria=[stopping_criteria])
@@ -108,12 +113,13 @@ if __name__ == "__main__":
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
     parser.add_argument("--model-base", type=str, default=None)
     parser.add_argument("--image-file", type=str, required=True)
-    parser.add_argument("--
+    parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--conv-mode", type=str, default=None)
     parser.add_argument("--temperature", type=float, default=0.2)
     parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--load-8bit", action="store_true")
     parser.add_argument("--load-4bit", action="store_true")
     parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
     args = parser.parse_args()
     main(args)
llava/serve/model_worker.py
CHANGED
@@ -45,7 +45,7 @@ class ModelWorker:
     def __init__(self, controller_addr, worker_addr,
                  worker_id, no_register,
                  model_path, model_base, model_name,
-                 load_8bit, load_4bit):
+                 load_8bit, load_4bit, device):
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id
@@ -60,9 +60,10 @@
         else:
             self.model_name = model_name

+        self.device = device
         logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, model_base, self.model_name, load_8bit, load_4bit)
+            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
         self.is_multimodal = 'llava' in self.model_name.lower()

         if not no_register:
@@ -159,7 +160,7 @@
         stop_str = params.get("stop", None)
         do_sample = True if temperature > 0.001 else False

-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
@@ -258,6 +259,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
     parser.add_argument("--model-base", type=str, default=None)
     parser.add_argument("--model-name", type=str)
+    parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
     parser.add_argument("--limit-model-concurrency", type=int, default=5)
     parser.add_argument("--stream-interval", type=int, default=1)
@@ -278,5 +280,6 @@
         args.model_base,
         args.model_name,
         args.load_8bit,
-        args.load_4bit
+        args.load_4bit,
+        args.device)
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
llava/train/train.py
CHANGED
@@ -163,12 +163,14 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
 def find_all_linear_names(model):
     cls = torch.nn.Linear
     lora_module_names = set()
+    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if isinstance(module, cls):
             names = name.split('.')
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])

-
     if 'lm_head' in lora_module_names: # needed for 16-bit
         lora_module_names.remove('lm_head')
     return list(lora_module_names)
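The filter added above keeps LoRA targets restricted to the language model's linear layers by skipping anything under the multimodal modules. The sketch below mirrors the updated function on a toy model and feeds its output into a PEFT LoraConfig; the toy ModuleDict and the specific LoRA hyperparameters are assumptions for illustration, not values from this commit.

import torch
import torch.nn as nn
from peft import LoraConfig

def find_all_linear_names(model):
    # Mirror of the updated train.py helper (sketch).
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    names = set()
    for name, module in model.named_modules():
        if any(k in name for k in multimodal_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            parts = name.split('.')
            names.add(parts[0] if len(parts) == 1 else parts[-1])
    names.discard('lm_head')  # needed for 16-bit
    return list(names)

# Toy stand-in: two LM projections, one multimodal projector, one lm_head.
model = nn.ModuleDict({
    'q_proj': nn.Linear(8, 8),
    'v_proj': nn.Linear(8, 8),
    'mm_projector': nn.Linear(8, 8),   # excluded by the keyword filter
    'lm_head': nn.Linear(8, 8),        # excluded for 16-bit stability
})
targets = find_all_linear_names(model)
print(sorted(targets))                 # ['q_proj', 'v_proj']
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=targets, lora_dropout=0.05)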