Commit 6144294
x-lai committed
1 Parent(s): a46000f

Release training script

Former-commit-id: 96ec3cdf6a4f6880ac274ddde55537570788ebbb
- model/LISA.py +10 -4
- model/llava/eval/model_vqa.py +8 -3
- model/llava/eval/model_vqa_science.py +8 -3
- model/llava/eval/run_llava.py +7 -3
- model/llava/eval/run_llava_batch.py +7 -3
- model/llava/eval/run_llava_batch_v2.py +7 -3
- model/llava/eval/run_llava_batch_v3.py +7 -3
- model/llava/model/llava.py +13 -5
- model/llava/model/llava_mpt.py +10 -4
- model/llava/model/mpt/adapt_tokenizer.py +1 -2
- model/llava/model/mpt/hf_prefixlm_converter.py +14 -10
- model/llava/model/mpt/modeling_mpt.py +9 -6
- model/llava/serve/gradio_web_server.py +7 -4
- model/llava/train/train.py +3 -2
- model/llava/train/train_mem.py +1 -2
- model/segment_anything/__init__.py +7 -2
- model/segment_anything/automatic_mask_generator.py +19 -9
- model/segment_anything/build_sam.py +7 -2
- train_ds.py +24 -17
- utils/dataset.py +12 -8
- utils/reason_seg_dataset.py +23 -9
- utils/refer_seg_dataset.py +8 -3
- utils/sem_seg_dataset.py +9 -3
- utils/utils.py +7 -2
- utils/vqa_dataset.py +13 -4
model/LISA.py
CHANGED
@@ -3,14 +3,18 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from peft import
+from peft import LoraConfig, get_peft_model
 from transformers import BitsAndBytesConfig, CLIPVisionModel

 from transformers import CLIPVisionModel, BitsAndBytesConfig
 from .llava.model.llava import LlavaLlamaForCausalLM
 from .segment_anything import build_sam_vit_h
-from utils.utils import (
-
+from utils.utils import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+)
+

 def dice_loss(
     inputs: torch.Tensor,
@@ -219,7 +223,9 @@ class LISA(nn.Module):
         self.lm.resize_token_embeddings(len(tokenizer))

         for n, p in self.lm.named_parameters():
-            if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(
+            if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(
+                tokenizer
+            ):
                 p.requires_grad = True

         # SAM
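Note on the second hunk: after `resize_token_embeddings`, the loop re-enables gradients only for `lm_head` and `embed_tokens` parameters whose first dimension matches the new vocabulary size. A minimal sketch of that selective-unfreeze pattern, assuming a generic Hugging Face causal LM ("gpt2" and the "[SEG]" token are placeholders, not what LISA actually loads):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

# Add a new special token and grow the embedding matrix to match.
tokenizer.add_tokens(["[SEG]"], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

# Unfreeze only the resized input/output embedding weights; freezing everything
# else first is part of this sketch, not something the hunk above shows.
for name, param in model.named_parameters():
    param.requires_grad = False
    resized = param.shape[0] == len(tokenizer)
    if resized and any(key in name for key in ["lm_head", "embed_tokens", "wte"]):
        param.requires_grad = True
```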
model/llava/eval/model_vqa.py
CHANGED
@@ -11,9 +11,14 @@ from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
 from PIL import Image
 from tqdm import tqdm
-from transformers import (
-
-
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)


 def split_list(lst, n):
model/llava/eval/model_vqa_science.py
CHANGED
@@ -11,9 +11,14 @@ from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
 from PIL import Image
 from tqdm import tqdm
-from transformers import (
-
-
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)


 def split_list(lst, n):
model/llava/eval/run_llava.py
CHANGED
@@ -9,9 +9,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (
-
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)

 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/eval/run_llava_batch.py
CHANGED
@@ -13,9 +13,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (
-
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)

 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/eval/run_llava_batch_v2.py
CHANGED
@@ -13,9 +13,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (
-
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)

 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/eval/run_llava_batch_v3.py
CHANGED
@@ -13,9 +13,13 @@ from llava.model import *
 from llava.model.utils import KeywordsStoppingCriteria
 from llava.utils import disable_torch_init
 from PIL import Image
-from transformers import (
-
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    StoppingCriteria,
+)

 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/model/llava.py
CHANGED
@@ -19,11 +19,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
-from transformers import (
-
-
-
-
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    LlamaConfig,
+    LlamaForCausalLM,
+    LlamaModel,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)

 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
model/llava/model/llava_mpt.py
CHANGED
@@ -21,10 +21,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
-from transformers import (
-
-
-
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)

 from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel

model/llava/model/mpt/adapt_tokenizer.py
CHANGED
@@ -1,7 +1,6 @@
 from typing import Union

-from transformers import
-    PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 NUM_SENTINEL_TOKENS: int = 100
model/llava/model/mpt/hf_prefixlm_converter.py
CHANGED
@@ -13,22 +13,26 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from transformers.models.bloom.modeling_bloom import (
-    BaseModelOutputWithPastAndCrossAttentions,
-
-
-
-
-
+    BaseModelOutputWithPastAndCrossAttentions,
+    BloomForCausalLM,
+    BloomModel,
+    CausalLMOutputWithCrossAttentions,
+    CrossEntropyLoss,
+)
+from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
+from transformers.models.bloom.modeling_bloom import (
+    _make_causal_mask as _make_causal_mask_bloom,
+)
 from transformers.models.bloom.modeling_bloom import logging
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
 from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from transformers.models.opt.modeling_opt import OPTForCausalLM
-from transformers.models.opt.modeling_opt import
-
-
-
+from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
+from transformers.models.opt.modeling_opt import (
+    _make_causal_mask as _make_causal_mask_opt,
+)

 logger = logging.get_logger(__name__)
 _SUPPORTED_GPT_MODELS = (
model/llava/model/mpt/modeling_mpt.py
CHANGED
@@ -9,17 +9,20 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import
-
-
-
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)

 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .attention import attn_bias_shape, build_attn_bias
 from .blocks import MPTBlock
 from .configuration_mpt import MPTConfig
-from .hf_prefixlm_converter import (
-
+from .hf_prefixlm_converter import (
+    add_bidirectional_mask_if_missing,
+    convert_hf_causal_lm_to_prefix_lm,
+)
 from .meta_init_context import init_empty_weights
 from .norm import NORM_CLASS_REGISTRY
 from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
model/llava/serve/gradio_web_server.py
CHANGED
@@ -9,12 +9,15 @@ from collections import defaultdict
 import gradio as gr
 import requests
 from llava.constants import LOGDIR
-from llava.conversation import
-    default_conversation)
+from llava.conversation import SeparatorStyle, conv_templates, default_conversation
 from llava.serve.gradio_css import code_highlight_css
 from llava.serve.gradio_patch import Chatbot as grChatbot
-from llava.utils import (
-
+from llava.utils import (
+    build_logger,
+    moderation_msg,
+    server_error_msg,
+    violates_moderation,
+)

 logger = build_logger("gradio_web_server", "gradio_web_server.log")

model/llava/train/train.py
CHANGED
@@ -715,8 +715,9 @@ def train():
             "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining"
         )

-        from torch.distributed.fsdp.fully_sharded_data_parallel import
-            FullyShardedDataParallel as FSDP
+        from torch.distributed.fsdp.fully_sharded_data_parallel import (
+            FullyShardedDataParallel as FSDP,
+        )

         def patch_FSDP_use_orig_params(func):
             def wrap_func(*args, **kwargs):
model/llava/train/train_mem.py
CHANGED
@@ -3,8 +3,7 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

 # Need to call this before importing transformers.
-from llava.train.llama_flash_attn_monkey_patch import
-    replace_llama_attn_with_flash_attn
+from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

 replace_llama_attn_with_flash_attn()

model/segment_anything/__init__.py
CHANGED
@@ -5,6 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 from .automatic_mask_generator import SamAutomaticMaskGenerator
-from .build_sam import (
-
+from .build_sam import (
+    build_sam,
+    build_sam_vit_b,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    sam_model_registry,
+)
 from .predictor import SamPredictor
model/segment_anything/automatic_mask_generator.py
CHANGED
@@ -12,13 +12,24 @@ from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

 from .modeling import Sam
 from .predictor import SamPredictor
-from .utils.amg import (
-
-
-
-
-
-
+from .utils.amg import (
+    MaskData,
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)


 class SamAutomaticMaskGenerator:
@@ -104,8 +115,7 @@ class SamAutomaticMaskGenerator:
             "coco_rle",
         ], f"Unknown output_mode {output_mode}."
         if output_mode == "coco_rle":
-            from pycocotools import
-                mask as mask_utils  # type: ignore # noqa: F401
+            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

         if min_mask_region_area > 0:
             import cv2  # type: ignore # noqa: F401
model/segment_anything/build_sam.py
CHANGED
@@ -8,8 +8,13 @@ from functools import partial

 import torch

-from .modeling import (
-
+from .modeling import (
+    ImageEncoderViT,
+    MaskDecoder,
+    PromptEncoder,
+    Sam,
+    TwoWayTransformer,
+)


 def build_sam_vit_h(checkpoint=None):
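The names re-exported in `model/segment_anything/__init__.py` above follow the upstream `segment_anything` package, where `sam_model_registry` maps a model-type string to one of these `build_sam_vit_*` functions. A usage sketch, assuming this vendored copy keeps the upstream registry keys (the checkpoint path is a placeholder):

```python
from model.segment_anything import SamPredictor, sam_model_registry

# "vit_h" mirrors the upstream registry key; the checkpoint path is illustrative.
sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
```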
train_ds.py
CHANGED
@@ -14,8 +14,13 @@ from torch.utils.tensorboard import SummaryWriter

 from model.LISA import LISA
 from utils.dataset import HybridDataset, ValDataset, collate_fn
-from utils.utils import (
-
+from utils.utils import (
+    AverageMeter,
+    ProgressMeter,
+    Summary,
+    dict_to_cuda,
+    intersectionAndUnionGPU,
+)


 def parse_args(args):
@@ -54,9 +59,7 @@ def parse_args(args):
     )
     parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
     parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
-    parser.add_argument(
-        "--val_dataset", default="ReasonSeg|val", type=str
-    )
+    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
     parser.add_argument("--dataset_dir", default="./dataset", type=str)
     parser.add_argument("--log_base_dir", default="./runs", type=str)
     parser.add_argument("--exp_name", default="lisa", type=str)
@@ -87,7 +90,9 @@ def parse_args(args):
     parser.add_argument("--exclude_val", action="store_true", default=False)
     parser.add_argument("--no_eval", action="store_true", default=False)
     parser.add_argument("--eval_only", action="store_true", default=False)
-    parser.add_argument(
+    parser.add_argument(
+        "--vision_pretrained", default="PATH TO SAM ViT-H Pre-trained Wegiht", type=str
+    )
     parser.add_argument("--weight", default="", type=str)
     parser.add_argument("--print_freq", default=1, type=int)
     parser.add_argument("--start_epoch", default=0, type=int)
@@ -133,7 +138,7 @@ def main(args):
     )

     if args.weight:
-        state_dict = torch.load(args.weight, map_location=
+        state_dict = torch.load(args.weight, map_location="cpu")
         model.load_state_dict(state_dict, strict=True)

     world_size = torch.cuda.device_count()
@@ -142,7 +147,10 @@ def main(args):
         args.dataset_dir,
         tokenizer,
         args.vision_tower,
-        samples_per_epoch=args.batch_size
+        samples_per_epoch=args.batch_size
+        * args.grad_accumulation_steps
+        * args.steps_per_epoch
+        * world_size,
         precision=args.precision,
         image_size=args.image_size,
         num_classes_per_sample=args.num_classes_per_sample,
@@ -163,7 +171,9 @@ def main(args):
            args.val_dataset,
            args.image_size,
        )
-        print(
+        print(
+            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
+        )
     else:
         val_dataset = None
         print(f"Training with {len(train_dataset)} examples.")
@@ -215,7 +225,9 @@ def main(args):

     if val_dataset is not None:
         assert args.val_batch_size == 1
-        val_sampler = torch.utils.data.distributed.DistributedSampler(
+        val_sampler = torch.utils.data.distributed.DistributedSampler(
+            val_dataset, shuffle=False, drop_last=False
+        )
         val_loader = torch.utils.data.DataLoader(
             val_dataset,
             batch_size=args.val_batch_size,
@@ -230,13 +242,10 @@ def main(args):
     best_score, cur_ciou = 0.0, 0.0

     if args.eval_only:
-        giou, ciou = validate(
-            val_loader, model_engine, 0, writer, args
-        )
+        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
         exit()

     for epoch in range(args.start_epoch, args.epochs):
-
         # train for one epoch
         train_iter = train(
             train_loader,
@@ -249,9 +258,7 @@ def main(args):
         )

         if args.no_eval == False:
-            giou, ciou = validate(
-                val_loader, model_engine, epoch, writer, args
-            )
+            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
             is_best = giou > best_score
             best_score = max(giou, best_score)
             cur_ciou = ciou if is_best else cur_ciou
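The `samples_per_epoch` hunk makes the epoch size explicit: per-GPU batch size times gradient-accumulation steps times optimizer steps per epoch times the number of GPUs. A small sketch of that arithmetic with illustrative values (these are not the script's defaults):

```python
# Illustrative values; the real ones come from the argparse flags above.
batch_size = 2                # per-GPU micro-batch size
grad_accumulation_steps = 10  # micro-batches per optimizer step
steps_per_epoch = 500         # optimizer steps that define one "epoch"
world_size = 8                # number of GPUs, i.e. torch.cuda.device_count()

samples_per_epoch = batch_size * grad_accumulation_steps * steps_per_epoch * world_size
print(samples_per_epoch)  # 80000 samples drawn per epoch across all ranks
```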
utils/dataset.py
CHANGED
@@ -17,8 +17,12 @@ from .reason_seg_dataset import ReasonSegDataset
 from .refer import REFER
 from .refer_seg_dataset import ReferSegDataset
 from .sem_seg_dataset import SemSegDataset
-from .utils import (
-
+from .utils import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+)
 from .vqa_dataset import VQADataset


@@ -67,7 +71,7 @@ def collate_fn(batch, tokenizer=None):
         max_length=tokenizer.model_max_length,
         truncation=True,
     )
-
+
     input_ids = tokenize_data.input_ids
     attention_masks = tokenize_data.attention_mask

@@ -261,7 +265,7 @@ class ValDataset(torch.utils.data.Dataset):
                 os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
             )
             self.images = images
-            self.data_type =
+            self.data_type = "reason_seg"
         elif len(splits) == 3:
             ds, splitBy, split = splits
             refer_api = REFER(self.base_image_dir, ds, splitBy)
@@ -294,7 +298,7 @@ class ValDataset(torch.utils.data.Dataset):
             ]
             refer_seg_ds["img2refs"] = img2refs
             self.refer_seg_ds = refer_seg_ds
-            self.data_type =
+            self.data_type = "refer_seg"

         self.ds = ds
         self.image_size = image_size
@@ -303,7 +307,7 @@ class ValDataset(torch.utils.data.Dataset):
         self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

     def __len__(self):
-        if self.data_type ==
+        if self.data_type == "refer_seg":
             return len(self.refer_seg_ds["images"])
         else:
             return len(self.images)
@@ -321,7 +325,7 @@ class ValDataset(torch.utils.data.Dataset):
         return x

     def __getitem__(self, idx):
-        if self.data_type ==
+        if self.data_type == "refer_seg":
             refer_seg_ds = self.refer_seg_ds
             images = refer_seg_ds["images"]
             annotations = refer_seg_ds["annotations"]
@@ -406,7 +410,7 @@ class ValDataset(torch.utils.data.Dataset):

         images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())

-        if self.data_type ==
+        if self.data_type == "refer_seg":
             masks = []
             for i, ann_id in enumerate(sampled_ann_ids):
                 ann = annotations[ann_id]
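The `data_type` hunks in `ValDataset` fill in a simple dispatch: a two-part `--val_dataset` string such as the default "ReasonSeg|val" is treated as reason_seg data, while a three-part string of the REFER form dataset|splitBy|split is treated as refer_seg data. A condensed sketch of that branching ("refcoco|unc|val" is only an illustrative REFER example, not taken from this diff):

```python
def infer_data_type(val_dataset: str) -> str:
    # Mirrors the len(splits) checks in ValDataset.__init__ above.
    splits = val_dataset.split("|")
    if len(splits) == 2:   # e.g. "ReasonSeg|val" -> ds, split
        return "reason_seg"
    if len(splits) == 3:   # e.g. "refcoco|unc|val" -> ds, splitBy, split
        return "refer_seg"
    raise ValueError(f"Unexpected val_dataset format: {val_dataset}")

assert infer_data_type("ReasonSeg|val") == "reason_seg"
assert infer_data_type("refcoco|unc|val") == "refer_seg"
```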
utils/reason_seg_dataset.py
CHANGED
@@ -13,10 +13,16 @@ from model.segment_anything.utils.transforms import ResizeLongestSide

 from .conversation import get_default_conv_template
 from .data_processing import get_mask_from_json
-from .utils import (
-
-
-
+from .utils import (
+    ANSWER_LIST,
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    EXPLANATORY_QUESTION_LIST,
+    LONG_QUESTION_LIST,
+    SHORT_QUESTION_LIST,
+)


 class ReasonSegDataset(torch.utils.data.Dataset):
@@ -72,7 +78,13 @@ class ReasonSegDataset(torch.utils.data.Dataset):
         self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
         self.img_to_explanation = {}
         with open(
-            os.path.join(
+            os.path.join(
+                base_image_dir,
+                "reason_seg",
+                reason_seg_data,
+                "explanatory",
+                "train.json",
+            )
         ) as f:
             items = json.load(f)
             for item in items:
@@ -131,9 +143,7 @@ class ReasonSegDataset(torch.utils.data.Dataset):
         ]

         image_name = image_path.split("/")[-1]
-        if (
-            self.explanatory != -1 and image_name in self.img_to_explanation
-        ):
+        if self.explanatory != -1 and image_name in self.img_to_explanation:
             if random.random() < self.explanatory:
                 choice = 2
             else:
@@ -200,7 +210,11 @@ class ReasonSegDataset(torch.utils.data.Dataset):
         images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())

         image_name = image_path.split("/")[-1]
-        if
+        if (
+            self.explanatory != -1
+            and image_name in self.img_to_explanation
+            and choice == 2
+        ):
             masks = torch.rand(0, *ori_size)
             label = torch.ones(ori_size) * self.ignore_label
         else:
utils/refer_seg_dataset.py
CHANGED
@@ -12,9 +12,14 @@ from model.segment_anything.utils.transforms import ResizeLongestSide

 from .conversation import get_default_conv_template
 from .refer import REFER
-from .utils import (
-
-
+from .utils import (
+    ANSWER_LIST,
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    SHORT_QUESTION_LIST,
+)


 class ReferSegDataset(torch.utils.data.Dataset):
utils/sem_seg_dataset.py
CHANGED
@@ -14,9 +14,15 @@ from transformers import CLIPImageProcessor
 from model.segment_anything.utils.transforms import ResizeLongestSide

 from .conversation import get_default_conv_template
-from .utils import (
-
-
+from .utils import (
+    ANSWER_LIST,
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    SHORT_QUESTION_LIST,
+)
+

 def init_mapillary(base_image_dir):
     mapillary_data_root = os.path.join(base_image_dir, "mapillary")
utils/utils.py
CHANGED
@@ -12,8 +12,12 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
 SHORT_QUESTION_LIST = [
     DEFAULT_IMAGE_TOKEN + " " + "Can you segment the {class_name} in this image?",
     DEFAULT_IMAGE_TOKEN + " " + "Please segment the {class_name} in this image.",
-    DEFAULT_IMAGE_TOKEN
-
+    DEFAULT_IMAGE_TOKEN
+    + " "
+    + "What is {class_name} in this image? Please respond with segmentation mask.",
+    DEFAULT_IMAGE_TOKEN
+    + " "
+    + "What is {class_name} in this image? Please output segmentation mask.",
 ]

 LONG_QUESTION_LIST = [
@@ -121,6 +125,7 @@ def intersectionAndUnionGPU(output, target, K, ignore_index=255):
     area_union = area_output + area_target - area_intersection
     return area_intersection, area_union, area_target

+
 class ProgressMeter(object):
     def __init__(self, num_batches, meters, prefix=""):
         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
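The restored entries in `SHORT_QUESTION_LIST` are `str.format` templates keyed on `{class_name}` and prefixed with the image token. A quick illustration of how one of them expands (the class name is an arbitrary example):

```python
DEFAULT_IMAGE_TOKEN = "<image>"

template = (
    DEFAULT_IMAGE_TOKEN
    + " "
    + "What is {class_name} in this image? Please respond with segmentation mask."
)
print(template.format(class_name="the dog"))
# <image> What is the dog in this image? Please respond with segmentation mask.
```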
utils/vqa_dataset.py
CHANGED
@@ -10,8 +10,13 @@ from transformers import CLIPImageProcessor
 from model.segment_anything.utils.transforms import ResizeLongestSide

 from .conversation import get_default_conv_template
-from .utils import (
-
+from .utils import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+)
+

 class VQADataset(torch.utils.data.Dataset):
     pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
@@ -49,7 +54,7 @@ class VQADataset(torch.utils.data.Dataset):
         self.vqa_data = vqa_data

         print("vqa_data: ", len(self.vqa_data))
-
+
     def __len__(self):
         return self.samples_per_epoch

@@ -72,7 +77,11 @@ class VQADataset(torch.utils.data.Dataset):
         img = cv2.imread(image_path)
         images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         ori_size = images.shape[:2]
-        images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
+        images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
+            "pixel_values"
+        ][
+            0
+        ]  # preprocess images for clip
         image_token_len = (images_clip.shape[1] // 14) * (
             images_clip.shape[2] // 14
         )  # FIXME: 14 is hardcoded patch size
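The wrapped preprocessing call feeds the line just below it, which counts how many ViT patch tokens the image yields under the hard-coded 14-pixel patch size. A standalone sketch of that arithmetic, assuming CLIP's default 224x224 input resolution:

```python
import torch

# Dummy tensor standing in for one preprocessed image in (C, H, W) layout,
# i.e. what the "pixel_values"[0] indexing above returns; 224x224 is assumed.
images_clip = torch.zeros(3, 224, 224)

image_token_len = (images_clip.shape[1] // 14) * (images_clip.shape[2] // 14)
print(image_token_len)  # 256 patch tokens: a 16 x 16 grid for a 224x224 input
```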