Commit 6144294 by x-lai
Parent(s): a46000f

Release training script

Former-commit-id: 96ec3cdf6a4f6880ac274ddde55537570788ebbb

model/LISA.py CHANGED
@@ -3,14 +3,18 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from peft import (LoraConfig, get_peft_model)
+from peft import LoraConfig, get_peft_model
 from transformers import BitsAndBytesConfig, CLIPVisionModel
 
 from transformers import CLIPVisionModel, BitsAndBytesConfig
 from .llava.model.llava import LlavaLlamaForCausalLM
 from .segment_anything import build_sam_vit_h
-from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                         DEFAULT_IMAGE_PATCH_TOKEN)
+from utils.utils import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+)
+
 
 def dice_loss(
     inputs: torch.Tensor,
@@ -219,7 +223,9 @@ class LISA(nn.Module):
         self.lm.resize_token_embeddings(len(tokenizer))
 
         for n, p in self.lm.named_parameters():
-            if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(tokenizer):
+            if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(
+                tokenizer
+            ):
                 p.requires_grad = True
 
         # SAM
model/llava/eval/model_vqa.py CHANGED
@@ -11,9 +11,14 @@ from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
 from PIL import Image
 from tqdm import tqdm
-from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel,
-                          StoppingCriteria)
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)
 
 
 def split_list(lst, n):
model/llava/eval/model_vqa_science.py CHANGED
@@ -11,9 +11,14 @@ from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
 from PIL import Image
 from tqdm import tqdm
-from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel,
-                          StoppingCriteria)
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)
 
 
 def split_list(lst, n):
model/llava/eval/run_llava.py CHANGED
@@ -9,9 +9,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel,
-                          StoppingCriteria)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/eval/run_llava_batch.py CHANGED
@@ -13,9 +13,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel,
-                          StoppingCriteria)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/eval/run_llava_batch_v2.py CHANGED
@@ -13,9 +13,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel,
-                          StoppingCriteria)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/eval/run_llava_batch_v3.py CHANGED
@@ -13,9 +13,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel,
-                          StoppingCriteria)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/model/llava.py CHANGED
@@ -19,11 +19,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
-from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor,
-                          CLIPVisionModel, LlamaConfig, LlamaForCausalLM,
-                          LlamaModel)
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                            CausalLMOutputWithPast)
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    LlamaConfig,
+    LlamaForCausalLM,
+    LlamaModel,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/model/llava_mpt.py CHANGED
@@ -21,10 +21,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
-from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor,
-                          CLIPVisionModel)
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                            CausalLMOutputWithPast)
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 
 from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
 
model/llava/model/mpt/adapt_tokenizer.py CHANGED
@@ -1,7 +1,6 @@
 from typing import Union
 
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 NUM_SENTINEL_TOKENS: int = 100
model/llava/model/mpt/hf_prefixlm_converter.py CHANGED
@@ -13,22 +13,26 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers.models.bloom.modeling_bloom import (
-    BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel,
-    CausalLMOutputWithCrossAttentions, CrossEntropyLoss)
-from transformers.models.bloom.modeling_bloom import \
-    _expand_mask as _expand_mask_bloom
-from transformers.models.bloom.modeling_bloom import \
-    _make_causal_mask as _make_causal_mask_bloom
+    BaseModelOutputWithPastAndCrossAttentions,
+    BloomForCausalLM,
+    BloomModel,
+    CausalLMOutputWithCrossAttentions,
+    CrossEntropyLoss,
+)
+from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
+from transformers.models.bloom.modeling_bloom import (
+    _make_causal_mask as _make_causal_mask_bloom,
+)
 from transformers.models.bloom.modeling_bloom import logging
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
 from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from transformers.models.opt.modeling_opt import OPTForCausalLM
-from transformers.models.opt.modeling_opt import \
-    _expand_mask as _expand_mask_opt
-from transformers.models.opt.modeling_opt import \
-    _make_causal_mask as _make_causal_mask_opt
+from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
+from transformers.models.opt.modeling_opt import (
+    _make_causal_mask as _make_causal_mask_opt,
+)
 
 logger = logging.get_logger(__name__)
 _SUPPORTED_GPT_MODELS = (
model/llava/model/mpt/modeling_mpt.py CHANGED
@@ -9,17 +9,20 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import (PreTrainedModel, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                            CausalLMOutputWithPast)
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .attention import attn_bias_shape, build_attn_bias
 from .blocks import MPTBlock
 from .configuration_mpt import MPTConfig
-from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing,
-                                    convert_hf_causal_lm_to_prefix_lm)
+from .hf_prefixlm_converter import (
+    add_bidirectional_mask_if_missing,
+    convert_hf_causal_lm_to_prefix_lm,
+)
 from .meta_init_context import init_empty_weights
 from .norm import NORM_CLASS_REGISTRY
 from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
model/llava/serve/gradio_web_server.py CHANGED
@@ -9,12 +9,15 @@ from collections import defaultdict
 import gradio as gr
 import requests
 from llava.constants import LOGDIR
-from llava.conversation import (SeparatorStyle, conv_templates,
-                                default_conversation)
+from llava.conversation import SeparatorStyle, conv_templates, default_conversation
 from llava.serve.gradio_css import code_highlight_css
 from llava.serve.gradio_patch import Chatbot as grChatbot
-from llava.utils import (build_logger, moderation_msg, server_error_msg,
-                         violates_moderation)
+from llava.utils import (
+    build_logger,
+    moderation_msg,
+    server_error_msg,
+    violates_moderation,
+)
 
 logger = build_logger("gradio_web_server", "gradio_web_server.log")
 
model/llava/train/train.py CHANGED
@@ -715,8 +715,9 @@ def train():
             "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining"
         )
 
-        from torch.distributed.fsdp.fully_sharded_data_parallel import \
-            FullyShardedDataParallel as FSDP
+        from torch.distributed.fsdp.fully_sharded_data_parallel import (
+            FullyShardedDataParallel as FSDP,
+        )
 
         def patch_FSDP_use_orig_params(func):
             def wrap_func(*args, **kwargs):
model/llava/train/train_mem.py CHANGED
@@ -3,8 +3,7 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
 # Need to call this before importing transformers.
-from llava.train.llama_flash_attn_monkey_patch import \
-    replace_llama_attn_with_flash_attn
+from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
 
 replace_llama_attn_with_flash_attn()
 
model/segment_anything/__init__.py CHANGED
@@ -5,6 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 from .automatic_mask_generator import SamAutomaticMaskGenerator
-from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
-                        build_sam_vit_l, sam_model_registry)
+from .build_sam import (
+    build_sam,
+    build_sam_vit_b,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    sam_model_registry,
+)
 from .predictor import SamPredictor
model/segment_anything/automatic_mask_generator.py CHANGED
@@ -12,13 +12,24 @@ from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
 
 from .modeling import Sam
 from .predictor import SamPredictor
-from .utils.amg import (MaskData, area_from_rle, batch_iterator,
-                        batched_mask_to_box, box_xyxy_to_xywh,
-                        build_all_layer_point_grids, calculate_stability_score,
-                        coco_encode_rle, generate_crop_boxes,
-                        is_box_near_crop_edge, mask_to_rle_pytorch,
-                        remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
-                        uncrop_masks, uncrop_points)
+from .utils.amg import (
+    MaskData,
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)
 
 
 class SamAutomaticMaskGenerator:
@@ -104,8 +115,7 @@ class SamAutomaticMaskGenerator:
             "coco_rle",
         ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
-            from pycocotools import \
-                mask as mask_utils  # type: ignore # noqa: F401
+            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
 
         if min_mask_region_area > 0:
             import cv2  # type: ignore # noqa: F401
model/segment_anything/build_sam.py CHANGED
@@ -8,8 +8,13 @@ from functools import partial
 
 import torch
 
-from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam,
-                       TwoWayTransformer)
+from .modeling import (
+    ImageEncoderViT,
+    MaskDecoder,
+    PromptEncoder,
+    Sam,
+    TwoWayTransformer,
+)
 
 
 def build_sam_vit_h(checkpoint=None):
train_ds.py CHANGED
@@ -14,8 +14,13 @@ from torch.utils.tensorboard import SummaryWriter
 
 from model.LISA import LISA
 from utils.dataset import HybridDataset, ValDataset, collate_fn
-from utils.utils import (AverageMeter, ProgressMeter, Summary, dict_to_cuda,
-                         intersectionAndUnionGPU)
+from utils.utils import (
+    AverageMeter,
+    ProgressMeter,
+    Summary,
+    dict_to_cuda,
+    intersectionAndUnionGPU,
+)
 
 
 def parse_args(args):
@@ -54,9 +59,7 @@ def parse_args(args):
     )
     parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
     parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
-    parser.add_argument(
-        "--val_dataset", default="ReasonSeg|val", type=str
-    )
+    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
     parser.add_argument("--dataset_dir", default="./dataset", type=str)
     parser.add_argument("--log_base_dir", default="./runs", type=str)
     parser.add_argument("--exp_name", default="lisa", type=str)
@@ -87,7 +90,9 @@ def parse_args(args):
     parser.add_argument("--exclude_val", action="store_true", default=False)
     parser.add_argument("--no_eval", action="store_true", default=False)
     parser.add_argument("--eval_only", action="store_true", default=False)
-    parser.add_argument("--vision_pretrained", default="PATH TO SAM ViT-H Pre-trained Wegiht", type=str)
+    parser.add_argument(
+        "--vision_pretrained", default="PATH TO SAM ViT-H Pre-trained Wegiht", type=str
+    )
     parser.add_argument("--weight", default="", type=str)
     parser.add_argument("--print_freq", default=1, type=int)
     parser.add_argument("--start_epoch", default=0, type=int)
@@ -133,7 +138,7 @@ def main(args):
     )
 
     if args.weight:
-        state_dict = torch.load(args.weight, map_location='cpu')
+        state_dict = torch.load(args.weight, map_location="cpu")
         model.load_state_dict(state_dict, strict=True)
 
     world_size = torch.cuda.device_count()
@@ -142,7 +147,10 @@ def main(args):
         args.dataset_dir,
         tokenizer,
        args.vision_tower,
-        samples_per_epoch=args.batch_size * args.grad_accumulation_steps * args.steps_per_epoch * world_size,
+        samples_per_epoch=args.batch_size
+        * args.grad_accumulation_steps
+        * args.steps_per_epoch
+        * world_size,
         precision=args.precision,
         image_size=args.image_size,
         num_classes_per_sample=args.num_classes_per_sample,
@@ -163,7 +171,9 @@ def main(args):
             args.val_dataset,
            args.image_size,
         )
-        print(f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples.")
+        print(
+            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
+        )
     else:
         val_dataset = None
         print(f"Training with {len(train_dataset)} examples.")
@@ -215,7 +225,9 @@ def main(args):
 
     if val_dataset is not None:
         assert args.val_batch_size == 1
-        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(
+            val_dataset, shuffle=False, drop_last=False
+        )
         val_loader = torch.utils.data.DataLoader(
             val_dataset,
             batch_size=args.val_batch_size,
@@ -230,13 +242,10 @@ def main(args):
     best_score, cur_ciou = 0.0, 0.0
 
     if args.eval_only:
-        giou, ciou = validate(
-            val_loader, model_engine, 0, writer, args
-        )
+        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
         exit()
 
     for epoch in range(args.start_epoch, args.epochs):
-
         # train for one epoch
         train_iter = train(
            train_loader,
@@ -249,9 +258,7 @@ def main(args):
         )
 
         if args.no_eval == False:
-            giou, ciou = validate(
-                val_loader, model_engine, epoch, writer, args
-            )
+            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou
utils/dataset.py CHANGED
@@ -17,8 +17,12 @@ from .reason_seg_dataset import ReasonSegDataset
 from .refer import REFER
 from .refer_seg_dataset import ReferSegDataset
 from .sem_seg_dataset import SemSegDataset
-from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                    DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN)
+from .utils import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+)
 from .vqa_dataset import VQADataset
 
 
@@ -67,7 +71,7 @@ def collate_fn(batch, tokenizer=None):
         max_length=tokenizer.model_max_length,
         truncation=True,
     )
-    
+
     input_ids = tokenize_data.input_ids
     attention_masks = tokenize_data.attention_mask
 
@@ -261,7 +265,7 @@ class ValDataset(torch.utils.data.Dataset):
                 os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
             )
             self.images = images
-            self.data_type = 'reason_seg'
+            self.data_type = "reason_seg"
         elif len(splits) == 3:
             ds, splitBy, split = splits
             refer_api = REFER(self.base_image_dir, ds, splitBy)
@@ -294,7 +298,7 @@ class ValDataset(torch.utils.data.Dataset):
             ]
             refer_seg_ds["img2refs"] = img2refs
             self.refer_seg_ds = refer_seg_ds
-            self.data_type = 'refer_seg'
+            self.data_type = "refer_seg"
 
         self.ds = ds
         self.image_size = image_size
@@ -303,7 +307,7 @@ class ValDataset(torch.utils.data.Dataset):
         self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
 
     def __len__(self):
-        if self.data_type == 'refer_seg':
+        if self.data_type == "refer_seg":
            return len(self.refer_seg_ds["images"])
        else:
            return len(self.images)
@@ -321,7 +325,7 @@ class ValDataset(torch.utils.data.Dataset):
         return x
 
     def __getitem__(self, idx):
-        if self.data_type == 'refer_seg':
+        if self.data_type == "refer_seg":
            refer_seg_ds = self.refer_seg_ds
            images = refer_seg_ds["images"]
            annotations = refer_seg_ds["annotations"]
@@ -406,7 +410,7 @@ class ValDataset(torch.utils.data.Dataset):
 
        images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
 
-        if self.data_type == 'refer_seg':
+        if self.data_type == "refer_seg":
            masks = []
            for i, ann_id in enumerate(sampled_ann_ids):
                ann = annotations[ann_id]
utils/reason_seg_dataset.py CHANGED
@@ -13,10 +13,16 @@ from model.segment_anything.utils.transforms import ResizeLongestSide
 
 from .conversation import get_default_conv_template
 from .data_processing import get_mask_from_json
-from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                    DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
-                    EXPLANATORY_QUESTION_LIST, LONG_QUESTION_LIST,
-                    SHORT_QUESTION_LIST)
+from .utils import (
+    ANSWER_LIST,
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    EXPLANATORY_QUESTION_LIST,
+    LONG_QUESTION_LIST,
+    SHORT_QUESTION_LIST,
+)
 
 
 class ReasonSegDataset(torch.utils.data.Dataset):
@@ -72,7 +78,13 @@ class ReasonSegDataset(torch.utils.data.Dataset):
            self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
            self.img_to_explanation = {}
            with open(
-                os.path.join(base_image_dir, "reason_seg", reason_seg_data, "explanatory", "train.json")
+                os.path.join(
+                    base_image_dir,
+                    "reason_seg",
+                    reason_seg_data,
+                    "explanatory",
+                    "train.json",
+                )
            ) as f:
                items = json.load(f)
                for item in items:
@@ -131,9 +143,7 @@ class ReasonSegDataset(torch.utils.data.Dataset):
        ]
 
        image_name = image_path.split("/")[-1]
-        if (
-            self.explanatory != -1 and image_name in self.img_to_explanation
-        ):
+        if self.explanatory != -1 and image_name in self.img_to_explanation:
            if random.random() < self.explanatory:
                choice = 2
            else:
@@ -200,7 +210,11 @@ class ReasonSegDataset(torch.utils.data.Dataset):
        images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
 
        image_name = image_path.split("/")[-1]
-        if self.explanatory != -1 and image_name in self.img_to_explanation and choice == 2:
+        if (
+            self.explanatory != -1
+            and image_name in self.img_to_explanation
+            and choice == 2
+        ):
            masks = torch.rand(0, *ori_size)
            label = torch.ones(ori_size) * self.ignore_label
        else:
utils/refer_seg_dataset.py CHANGED
@@ -12,9 +12,14 @@ from model.segment_anything.utils.transforms import ResizeLongestSide
 
 from .conversation import get_default_conv_template
 from .refer import REFER
-from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                    DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
-                    SHORT_QUESTION_LIST)
+from .utils import (
+    ANSWER_LIST,
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    SHORT_QUESTION_LIST,
+)
 
 
 class ReferSegDataset(torch.utils.data.Dataset):
utils/sem_seg_dataset.py CHANGED
@@ -14,9 +14,15 @@ from transformers import CLIPImageProcessor
 from model.segment_anything.utils.transforms import ResizeLongestSide
 
 from .conversation import get_default_conv_template
-from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                    DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
-                    SHORT_QUESTION_LIST)
+from .utils import (
+    ANSWER_LIST,
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    SHORT_QUESTION_LIST,
+)
+
 
 def init_mapillary(base_image_dir):
     mapillary_data_root = os.path.join(base_image_dir, "mapillary")
utils/utils.py CHANGED
@@ -12,8 +12,12 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
 SHORT_QUESTION_LIST = [
     DEFAULT_IMAGE_TOKEN + " " + "Can you segment the {class_name} in this image?",
     DEFAULT_IMAGE_TOKEN + " " + "Please segment the {class_name} in this image.",
-    DEFAULT_IMAGE_TOKEN + " " + "What is {class_name} in this image? Please respond with segmentation mask.",
-    DEFAULT_IMAGE_TOKEN + " " + "What is {class_name} in this image? Please output segmentation mask.",
+    DEFAULT_IMAGE_TOKEN
+    + " "
+    + "What is {class_name} in this image? Please respond with segmentation mask.",
+    DEFAULT_IMAGE_TOKEN
+    + " "
+    + "What is {class_name} in this image? Please output segmentation mask.",
 ]
 
 LONG_QUESTION_LIST = [
@@ -121,6 +125,7 @@ def intersectionAndUnionGPU(output, target, K, ignore_index=255):
     area_union = area_output + area_target - area_intersection
     return area_intersection, area_union, area_target
 
+
 class ProgressMeter(object):
     def __init__(self, num_batches, meters, prefix=""):
         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
utils/vqa_dataset.py CHANGED
@@ -10,8 +10,13 @@ from transformers import CLIPImageProcessor
 from model.segment_anything.utils.transforms import ResizeLongestSide
 
 from .conversation import get_default_conv_template
-from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                    DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN)
+from .utils import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+)
+
 
 class VQADataset(torch.utils.data.Dataset):
     pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
@@ -49,7 +54,7 @@ class VQADataset(torch.utils.data.Dataset):
        self.vqa_data = vqa_data
 
        print("vqa_data: ", len(self.vqa_data))
-    
+
    def __len__(self):
        return self.samples_per_epoch
 
@@ -72,7 +77,11 @@ class VQADataset(torch.utils.data.Dataset):
        img = cv2.imread(image_path)
        images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ori_size = images.shape[:2]
-        images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")["pixel_values"][0]  # preprocess images for clip
+        images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
+            "pixel_values"
+        ][
+            0
+        ]  # preprocess images for clip
        image_token_len = (images_clip.shape[1] // 14) * (
            images_clip.shape[2] // 14
        )  # FIXME: 14 is hardcoded patch size