File size: 69,120 Bytes
3189e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 |
# coding=utf-8
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch SAM model."""
import collections
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from transformers.models.sam.configuration_sam import SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from .configuration_sam_hq import SamHQConfig
# Module-level logger obtained from the logging helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
# Documentation constants. NOTE(review): their consumers are not visible in this part of the file; presumably
# they are fed to the `add_start_docstrings*` helpers imported above - confirm.
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
@dataclass
class SamVisionEncoderOutput(ModelOutput):
    """
    Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the
    projection layer to the pooler_output.

    Every field defaults to `None`.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class SamImageSegmentationOutput(ModelOutput):
    """
    Base class for the Segment-Anything model's output.

    Every field defaults to `None`.

    Args:
        iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
            The IoU scores of the predicted masks.
        pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low-resolution masks. They need to be post-processed by the processor.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding
            outputs.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    iou_scores: torch.FloatTensor = None
    pred_masks: torch.FloatTensor = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class SamPatchEmbeddings(nn.Module):
    """
    Projects `pixel_values` of shape `(batch_size, num_channels, height, width)` onto a grid of patch embeddings
    with a strided convolution. The convolution output is permuted to channels-last, so `forward` returns a tensor
    of shape `(batch_size, height // patch_height, width // patch_width, hidden_size)`.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Iterables are used as given; anything else is duplicated into a square `(value, value)` pair.
            if isinstance(value, collections.abc.Iterable):
                return value
            return (value, value)

        self.image_size = _as_pair(config.image_size)
        self.patch_size = _as_pair(config.patch_size)
        self.num_channels = config.num_channels

        patches_along_height = self.image_size[0] // self.patch_size[0]
        patches_along_width = self.image_size[1] // self.patch_size[1]
        self.num_patches = patches_along_width * patches_along_height

        # Kernel size == stride, so each patch is projected independently of its neighbours.
        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size
        )

    def forward(self, pixel_values):
        """
        Embed `pixel_values`.

        Raises:
            ValueError: if the channel count or the spatial size differs from the configured values.
        """
        _, num_channels, height, width = pixel_values.shape

        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        size_matches = height == self.image_size[0] and width == self.image_size[1]
        if not size_matches:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        projected = self.projection(pixel_values)
        # (batch, hidden, rows, cols) -> (batch, rows, cols, hidden)
        return projected.permute(0, 2, 3, 1)
class SamMLPBlock(nn.Module):
    """
    Feed-forward block: a linear layer from `config.hidden_size` to `config.mlp_dim`, the activation selected by
    `config.hidden_act`, and a linear layer back to `config.hidden_size`.
    """

    def __init__(self, config):
        super().__init__()
        model_dim, inner_dim = config.hidden_size, config.mlp_dim
        self.lin1 = nn.Linear(model_dim, inner_dim)
        self.lin2 = nn.Linear(inner_dim, model_dim)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `lin1`, the activation and `lin2` in sequence."""
        return self.lin2(self.act(self.lin1(hidden_states)))
# Adapted from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.Module):
    r"""Layer normalization over the channel axis for two memory layouts.

    `data_format="channels_last"` (default) expects the channels on the last axis, e.g. (batch_size, height, width,
    channels); `data_format="channels_first"` expects them on axis 1, e.g. (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` over its channel axis and apply the learned scale and shift."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

        if self.data_format == "channels_first":
            # Statistics are computed in float32 and the result is cast back to the incoming dtype
            # before the affine transform is applied.
            incoming_dtype = x.dtype
            values = x.float()
            centered = values - values.mean(1, keepdim=True)
            variance = centered.pow(2).mean(1, keepdim=True)
            normalized = (centered / torch.sqrt(variance + self.eps)).to(dtype=incoming_dtype)
            return self.weight[:, None, None] * normalized + self.bias[:, None, None]

        return x
class SamAttention(nn.Module):
    """
    Multi-head attention whose query/key/value projections may map to a smaller internal width
    (`hidden_size // downsample_rate`) before the output projection maps back to `hidden_size`.
    """

    def __init__(self, config, downsample_rate=None):
        super().__init__()
        if downsample_rate is None:
            downsample_rate = config.attention_downsample_rate

        self.hidden_size = config.hidden_size
        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        if self.internal_dim % config.num_attention_heads != 0:
            raise ValueError("num_attention_heads must divide hidden_size.")

        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)

    def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
        """(batch, point_batch, tokens, channels) -> (batch * point_batch, heads, tokens, channels // heads)."""
        batch, point_batch_size, n_tokens, channel = hidden_states.shape
        per_head = channel // num_attention_heads
        merged_batch = batch * point_batch_size
        split = hidden_states.reshape(merged_batch, n_tokens, num_attention_heads, per_head)
        return split.transpose(1, 2)

    def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
        """Inverse of `_separate_heads`: merge the heads and restore the point-batch axis."""
        merged_batch, n_heads, n_tokens, per_head = hidden_states.shape
        tokens_first = hidden_states.transpose(1, 2)
        return tokens_first.reshape(
            merged_batch // point_batch_size, point_batch_size, n_tokens, n_heads * per_head
        )

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
    ) -> Tensor:
        """
        Attend from `query` to `key`/`value` (all 4-D: batch, point_batch, tokens, channels).

        When `attention_similarity` is given it is added to the softmax-normalized attention weights, and the sum
        is passed through a second softmax.
        """
        # The point-batch axis is untouched by the linear projections, so it can be read up front.
        point_batch_size = query.shape[1]
        n_heads = self.num_attention_heads

        # Project, then split into heads
        q_heads = self._separate_heads(self.q_proj(query), n_heads)
        k_heads = self._separate_heads(self.k_proj(key), n_heads)
        v_heads = self._separate_heads(self.v_proj(value), n_heads)

        # Scaled dot-product weights: batch_size * point_batch_size x N_heads x N_tokens x N_tokens
        head_dim = q_heads.shape[-1]
        weights = (q_heads @ k_heads.transpose(-2, -1)) / math.sqrt(head_dim)
        weights = torch.softmax(weights, dim=-1)
        if attention_similarity is not None:
            weights = torch.softmax(weights + attention_similarity, dim=-1)

        # Weighted sum of the values, merged back and projected to hidden_size
        context = self._recombine_heads(weights @ v_heads, point_batch_size)
        return self.out_proj(context)
class SamTwoWayAttentionBlock(nn.Module):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
        """
        Transformer block made of four stages, each followed by a layer norm:

        (1) self-attention over the sparse inputs, (2) cross-attention from the sparse inputs to the dense inputs,
        (3) an MLP on the sparse inputs, (4) cross-attention from the dense inputs to the sparse inputs.

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration used to instantiate the block.
            attention_downsample_rate (*optional*, int, defaults to 2):
                Downsample ratio used to reduce the inner dimension of the two cross-attention layers.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether to skip adding `query_point_embedding` in the self-attention stage.
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps

        def _new_norm():
            return nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        # Sub-modules are created in the same order as before so that seeded initialisation is unchanged.
        self.self_attn = SamAttention(config, downsample_rate=1)
        self.layer_norm1 = _new_norm()

        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
        self.layer_norm2 = _new_norm()

        self.mlp = SamMLPBlock(config)
        self.layer_norm3 = _new_norm()

        self.layer_norm4 = _new_norm()
        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        output_attentions: bool = False,
    ):
        """
        Run the four stages and return `(queries, keys, extra)`.

        `extra` is the output of the final image-to-token cross-attention when `output_attentions` is truthy
        (the layer's output tensor, not its attention weights), otherwise `None`.
        """
        # Stage 1: self-attention on the sparse tokens
        if self.skip_first_layer_pe:
            # No positional term and no residual connection in this mode.
            queries = self.self_attn(query=queries, key=queries, value=queries)
        else:
            positioned = queries + query_point_embedding
            queries = queries + self.self_attn(query=positioned, key=positioned, value=queries)
        queries = self.layer_norm1(queries)

        # Stage 2: tokens attend to the image embedding
        token_query = queries + query_point_embedding
        image_key = keys + key_point_embedding
        token_update = self.cross_attn_token_to_image(
            query=token_query, key=image_key, value=keys, attention_similarity=attention_similarity
        )
        queries = self.layer_norm2(queries + token_update)

        # Stage 3: MLP on the tokens
        queries = self.layer_norm3(queries + self.mlp(queries))

        # Stage 4: image embedding attends to the tokens. `keys` has not changed since stage 2,
        # so `image_key` is still `keys + key_point_embedding`.
        token_query = queries + query_point_embedding
        attn_out = self.cross_attn_image_to_token(query=image_key, key=token_query, value=queries)
        keys = self.layer_norm4(keys + attn_out)

        return (queries, keys, attn_out if output_attentions else None)
class SamTwoWayTransformer(nn.Module):
    """
    Stack of `SamTwoWayAttentionBlock` layers followed by one final token-to-image attention and a layer norm.
    """

    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        self.config = config

        self.num_hidden_layers = config.num_hidden_layers
        self.layers = nn.ModuleList()

        # Only the first block skips the positional term in its self-attention stage.
        for i in range(self.num_hidden_layers):
            self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

        self.final_attn_token_to_image = SamAttention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor,
        target_embedding=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[Tensor, Tensor, Tuple]:
        """
        Run the two-way transformer.

        Args:
            point_embeddings (`Tensor`):
                Token embeddings used both as the initial queries and as the query positional term.
            image_embeddings (`Tensor`):
                Image features; flattened from dim 2 onward, moved to channels-last and given a singleton axis at
                dim 1 before being used as keys.
            image_positional_embeddings (`Tensor`):
                Positional term for the keys; reshaped exactly like `image_embeddings`.
            attention_similarity (`Tensor`):
                Forwarded to every block (may be `None`).
            target_embedding (`Tensor`, *optional*):
                When given, added to the queries before every block.
            output_attentions (`bool`, *optional*):
                Falls back to `config.output_attentions` when `None`.
            output_hidden_states (`bool`, *optional*):
                Accepted for signature compatibility; not used by this module.
            return_dict (`bool`, *optional*):
                Accepted for signature compatibility; a plain tuple is always returned.

        Returns:
            `(queries, keys, all_attentions)`. `all_attentions` is a tuple holding the third output of each block
            when `output_attentions` is truthy, otherwise an empty tuple.

        Raises:
            ValueError: if `image_embeddings` is `None`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions

        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)

        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings

        all_attentions = ()

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            if target_embedding is not None:
                # NOTE(review): this addition is in place. On the first iteration `queries` is the very same
                # tensor object as `point_embeddings`, so the caller's tensor (also used as the positional term
                # below) is modified too. Kept as-is to preserve existing numerics - confirm this is intended.
                queries += target_embedding

            queries, keys, attention_outputs = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                output_attentions=output_attentions,
            )

            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)

        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)

        queries = self.layer_norm_final_attn(queries + attn_out)
        return queries, keys, all_attentions
class SamFeedForward(nn.Module):
    """
    ReLU multi-layer perceptron: an input projection, `num_layers - 2` hidden layers of width `hidden_dim`, and an
    output projection, optionally followed by a sigmoid.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
    ):
        super().__init__()
        self.num_layers = num_layers
        self.activation = nn.ReLU()
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        # The input and output projections account for two of the `num_layers` layers.
        hidden_layer_count = num_layers - 2
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layer_count)])
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        """Apply the MLP to the last dimension of `hidden_states`."""
        # Every layer except the output projection is followed by the activation.
        for linear in (self.proj_in, *self.layers):
            hidden_states = self.activation(linear(hidden_states))

        projected = self.proj_out(hidden_states)
        if not self.sigmoid_output:
            return projected
        return F.sigmoid(projected)
class SamMaskDecoderHQ(nn.Module):
    """
    SAM mask decoder extended with an extra high-quality (HQ) output token.

    On top of the IoU token and the `num_multimask_outputs + 1` SAM mask tokens, the decoder owns one HQ token
    with its own hypernetwork MLP. The HQ mask is predicted on features that fuse the upscaled decoder output
    with the image embeddings and an intermediate vision-encoder feature map.
    """

    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.vision_encoder_dim = config.vision_encoder_dim

        self.num_multimask_outputs = config.num_multimask_outputs
        # Number of original SAM mask tokens; incremented further down once the HQ token has been added.
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)

        self.transformer = SamTwoWayTransformer(config)

        # should we create a new class for this?
        self.upscale_conv1 = nn.ConvTranspose2d(
            self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2
        )
        self.upscale_conv2 = nn.ConvTranspose2d(
            self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2
        )
        self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()

        # One hypernetwork MLP per original SAM mask token.
        mlps_list = []
        for _ in range(self.num_mask_tokens):
            mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)

        self.iou_prediction_head = SamFeedForward(
            self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
        )

        # HQ-SAM parameters
        self.hf_token = nn.Embedding(1, self.hidden_size)  # HQ-Output-Token
        self.hf_mlp = SamFeedForward(
            self.hidden_size, self.hidden_size, self.hidden_size // 8, 3
        )  # corresponding new MLP layer for the HQ-Output-Token
        # From here on `num_mask_tokens` counts the SAM mask tokens plus the HQ token.
        self.num_mask_tokens = self.num_mask_tokens + 1

        # three conv fusion layers for obtaining the HQ-Feature
        self.compress_vit_feat = nn.Sequential(
            nn.ConvTranspose2d(self.vision_encoder_dim, self.hidden_size, kernel_size=2, stride=2),
            SamLayerNorm(self.hidden_size, data_format="channels_first"),
            nn.GELU(),
            nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 8, kernel_size=2, stride=2),
        )

        self.embedding_encoder = nn.Sequential(
            nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2),
            SamLayerNorm(self.hidden_size // 4, data_format="channels_first"),
            nn.GELU(),
            nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2),
        )

        self.embedding_maskfeature = nn.Sequential(
            nn.Conv2d(self.hidden_size // 8, self.hidden_size // 4, 3, 1, 1),
            SamLayerNorm(self.hidden_size // 4, data_format="channels_first"),
            nn.GELU(),
            nn.Conv2d(self.hidden_size // 4, self.hidden_size // 8, 3, 1, 1),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        intermediate_vision_embeddings: torch.Tensor,
        hq_token_only: bool = False,
        output_attentions: Optional[bool] = None,
        attention_similarity: torch.Tensor = None,
        target_embedding: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple]]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                The embeddings from the image encoder, 4-D `(batch_size, channels, height, width)`.
            image_positional_embeddings (`torch.Tensor`):
                Positional encoding with the shape of `image_embeddings`.
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes; dim 1 is the point-batch size.
            dense_prompt_embeddings (`torch.Tensor`):
                The embeddings of the mask inputs; added to `image_embeddings`.
            multimask_output (bool):
                If `True`, the SAM mask is the best-scoring one among SAM mask tokens 1 and above; otherwise SAM
                mask token 0 is used.
            intermediate_vision_embeddings (`torch.Tensor`):
                Intermediate vision-encoder features. Only element 0 is used; it is expected channels-last (it is
                permuted with `(0, 3, 1, 2)` before the convolutions).
            hq_token_only (bool, defaults to `False`):
                If `True`, return the HQ mask alone; otherwise return the sum of the selected SAM mask and the HQ
                mask.
            output_attentions (bool, *optional*):
                Whether or not to return the attentions collected by the transformer.
            attention_similarity (`torch.Tensor`, *optional*):
                Forwarded to the transformer.
            target_embedding (`torch.Tensor`, *optional*):
                Forwarded to the transformer.

        Returns:
            `(masks, iou_pred, attentions)`; `attentions` is `None` unless `output_attentions` is truthy.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # `self.num_mask_tokens` includes the HQ token; the SAM-only count is one less.
        num_sam_tokens = self.num_mask_tokens - 1

        # early-layer ViT feature, after 1st global attention block in ViT (per the original comment)
        vit_inter_features = intermediate_vision_embeddings[0].permute(0, 3, 1, 2)
        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_inter_features)

        # Concatenate output tokens: [iou, SAM mask tokens..., HQ]
        output_tokens = torch.cat(
            [self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0
        )
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        # Prompts are appended only when they are not all-zero (as judged by their sum).
        if sparse_prompt_embeddings.sum().item() != 0:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-point
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
        # The HQ features must follow the same per-image -> per-point expansion as the image embeddings.
        # (`repeat(batch_size * point_batch_size, ...)` only produced the right row count for batch_size == 1,
        # where both forms are identical.)
        hq_features = hq_features.repeat_interleave(point_batch_size, 0)

        # Run the transformer, image_positional_embedding are consumed
        point_embedding, image_embeddings, attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        upscaled_embedding_sam = self.upscale_conv1(image_embeddings)
        upscaled_embedding_sam = self.activation(self.upscale_layer_norm(upscaled_embedding_sam))
        upscaled_embedding_sam = self.activation(self.upscale_conv2(upscaled_embedding_sam))
        upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features

        # One hypernetwork output per mask token; the last (HQ) token uses its dedicated MLP.
        hyper_in_list = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, :, i, :]) for i in range(num_sam_tokens)
        ]
        hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, :, num_sam_tokens, :]))
        hyper_in = torch.stack(hyper_in_list, dim=2)

        _, num_channels, height, width = upscaled_embedding_sam.shape
        upscaled_embedding_sam = upscaled_embedding_sam.reshape(
            batch_size, point_batch_size, num_channels, height * width
        )
        upscaled_embedding_hq = upscaled_embedding_hq.reshape(
            batch_size, point_batch_size, num_channels, height * width
        )

        masks_sam = (hyper_in[:, :, :num_sam_tokens] @ upscaled_embedding_sam).reshape(
            batch_size, point_batch_size, -1, height, width
        )
        masks_hq = (hyper_in[:, :, num_sam_tokens:] @ upscaled_embedding_hq).reshape(
            batch_size, point_batch_size, 1, height, width
        )
        masks = torch.cat([masks_sam, masks_hq], dim=2)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        # Select the correct mask or masks for output
        if multimask_output:
            # mask with highest score among SAM mask tokens 1..num_sam_tokens-1
            mask_slice = slice(1, num_sam_tokens)
            # NOTE(review): after `torch.max` `iou_pred` is (batch, point_batch) while the single-mask branch
            # yields (batch, point_batch, 1). Left unchanged for callers - confirm which shape is expected.
            iou_pred, max_iou_idx = torch.max(iou_pred[:, :, mask_slice], dim=2)
            masks_multi = masks[:, :, mask_slice, :, :]
            # Pick, for every (batch, point) pair, its own best mask. `gather` keeps the mask axis, giving
            # (batch, point_batch, 1, height, width) for any batch / point-batch size.
            gather_index = max_iou_idx[:, :, None, None, None].expand(-1, -1, 1, height, width)
            masks_sam = torch.gather(masks_multi, 2, gather_index)
        else:
            # single mask output, default
            mask_slice = slice(0, 1)
            iou_pred = iou_pred[:, :, mask_slice]
            masks_sam = masks[:, :, mask_slice, :, :]

        if hq_token_only:
            masks = masks_hq
        else:
            masks = masks_sam + masks_hq

        return (masks, iou_pred, attentions if output_attentions else None)
class SamPositionalEmbedding(nn.Module):
    """
    Random-projection positional encoding for 2-D coordinates.

    A `(2, config.num_pos_feats)` matrix drawn from a normal distribution and scaled by `config.hidden_size // 2`
    is stored as a buffer; coordinates are projected through it and encoded with sine and cosine.
    """

    def __init__(self, config):
        super().__init__()
        self.scale = config.hidden_size // 2
        self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))

    def forward(self, input_coords, input_shape=None):
        """
        Positionally encode points that are normalized to [0,1].

        Args:
            input_coords (`torch.Tensor`):
                Coordinates of shape `d_1 x ... x d_n x 2`, with x at index 0 and y at index 1 of the last axis.
                The tensor is cloned, so the caller's data is never modified.
            input_shape (*optional*):
                `(height, width)` pair. When given, x is divided by `input_shape[1]` and y by `input_shape[0]`
                before encoding; when `None`, the coordinates are assumed to be normalized already.

        Returns:
            `torch.Tensor` of shape `d_1 x ... x d_n x (2 * num_pos_feats)`.
        """
        coordinates = input_coords.clone()

        if input_shape is not None:
            # Index with an ellipsis so any number of leading dimensions is accepted, not only three.
            coordinates[..., 0] = coordinates[..., 0] / input_shape[1]
            coordinates[..., 1] = coordinates[..., 1] / input_shape[0]

        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coordinates = 2 * coordinates - 1
        coordinates = coordinates.to(self.positional_embedding.dtype)
        coordinates = coordinates @ self.positional_embedding
        coordinates = 2 * math.pi * coordinates
        # outputs d_1 x ... x d_n x channel shape
        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
class SamMaskEmbedding(nn.Module):
    """
    Embeds single-channel mask inputs with two stride-2 convolutions (each followed by a channels-first layer norm
    and the configured activation) and a final 1x1 convolution to `config.hidden_size` channels.
    """

    def __init__(self, config: SamPromptEncoderConfig):
        super().__init__()
        full_channels = config.mask_input_channels
        # Width of the first convolution: a quarter of the configured mask channels.
        self.mask_input_channels = full_channels // 4
        self.activation = ACT2FN[config.hidden_act]

        reduced_channels = self.mask_input_channels
        self.conv1 = nn.Conv2d(1, reduced_channels, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(reduced_channels, full_channels, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(full_channels, config.hidden_size, kernel_size=1)

        norm_options = {"eps": config.layer_norm_eps, "data_format": "channels_first"}
        self.layer_norm1 = SamLayerNorm(reduced_channels, **norm_options)
        self.layer_norm2 = SamLayerNorm(reduced_channels * 4, **norm_options)

    def forward(self, masks):
        """Return the dense embeddings computed from `masks`."""
        stage_one = self.activation(self.layer_norm1(self.conv1(masks)))
        stage_two = self.activation(self.layer_norm2(self.conv2(stage_one)))
        return self.conv3(stage_two)
class SamPromptEncoder(nn.Module):
    """Turns point, box and mask prompts into sparse and dense prompt embeddings.

    Points and box corners are encoded with the positional embedding module shared with the image side; learned
    type embeddings (`point_embed`) are added on top. Masks go through `SamMaskEmbedding`; when no mask is given
    a learned "no mask" embedding is broadcast over the image embedding grid.
    """

    def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding):
        super().__init__()
        self.shared_embedding = shared_patch_embedding
        self.mask_embed = SamMaskEmbedding(config)
        self.no_mask_embed = nn.Embedding(1, config.hidden_size)

        self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
        self.input_image_size = config.image_size

        # Indices 0/1 are added to points labelled 0/1; indices 2/3 to the first/second box corner.
        self.point_embed = nn.ModuleList(
            [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
        )
        self.hidden_size = config.hidden_size
        self.not_a_point_embed = nn.Embedding(1, config.hidden_size)

    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """Embeds point prompts.

        Args:
            points (`torch.Tensor`):
                Point coordinates; points are concatenated along dimension 2.
            labels (`torch.Tensor`):
                One label per point. `-1` selects the "not a point" embedding, `-10` zeroes the embedding, `0` and
                `1` add the corresponding learned point embedding.
            pad (`bool`):
                Whether to append one extra point at the origin with label `-1` to every point set.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            batch, point_batch = points.shape[0], points.shape[1]
            extra_point = torch.zeros((batch, point_batch, 1, points.shape[-1]), device=points.device)
            extra_label = -torch.ones((batch, point_batch, 1), device=labels.device)
            points = torch.cat([points, extra_point], dim=2)
            labels = torch.cat([labels, extra_label], dim=2)

        image_shape = (self.input_image_size, self.input_image_size)
        embedding = self.shared_embedding(points, image_shape)

        # torch.where and expanding the labels tensor is required by the ONNX export
        expanded_labels = labels[..., None]
        embedding = torch.where(expanded_labels == -1, self.not_a_point_embed.weight, embedding)

        # This is required for the ONNX export. The dtype and device need to be explicitly
        # specified as otherwise torch.onnx.export interprets the constant as double
        zero = torch.tensor(0.0, dtype=embedding.dtype, device=embedding.device)
        embedding = torch.where(expanded_labels != -10, embedding, zero)

        # Add the learned embedding of label 0, then of label 1, to the matching points.
        for label_value in (0, 1):
            embedding = torch.where(
                (labels == label_value)[:, :, :, None],
                embedding + self.point_embed[label_value].weight[None, None, :, :],
                embedding,
            )
        return embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts.

        Each box is split into two corner points (`reshape(..., 2, 2)`), which are positionally encoded; the
        learned corner embeddings `point_embed[2]` and `point_embed[3]` are then added in place.
        """
        boxes = boxes + 0.5  # Shift to center of pixel
        batch_size, nb_boxes = boxes.shape[:2]
        corners = boxes.reshape(batch_size, nb_boxes, 2, 2)
        image_shape = (self.input_image_size, self.input_image_size)
        corner_embedding = self.shared_embedding(corners, image_shape)
        corner_embedding[:, :, 0, :] += self.point_embed[2].weight
        corner_embedding[:, :, 1, :] += self.point_embed[3].weight
        return corner_embedding

    def forward(
        self,
        input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        input_labels: Optional[torch.Tensor],
        input_boxes: Optional[torch.Tensor],
        input_masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            input_points (`torch.Tensor`, *optional*):
                Point coordinates to embed. Used as a tensor (its `.shape` is read); the first dimension is the
                batch size.
            input_labels (`torch.Tensor`, *optional*):
                Labels of the points. Required whenever `input_points` is given.
            input_boxes (`torch.Tensor`, *optional*):
                Boxes to embed. Points are only padded with an extra "not a point" entry when no boxes are given.
            input_masks (`torch.Tensor`, *optional*):
                Masks to embed.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: `(sparse_embeddings, dense_embeddings)`. Without points and
            boxes the sparse embeddings are zeros of shape `(batch_size, 1, 1, hidden_size)`; without masks the
            dense embeddings are the "no mask" embedding expanded to `(batch_size, hidden_size, *image_embedding_size)`.

        Raises:
            ValueError: if `input_points` is given without `input_labels`.
        """
        sparse_embeddings = None
        batch_size = 1
        target_device = self.shared_embedding.positional_embedding.device

        if input_points is not None:
            batch_size, _ = input_points.shape[:2]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            sparse_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))

        if input_boxes is not None:
            batch_size = input_boxes.shape[0]
            box_embeddings = self._embed_boxes(input_boxes)
            sparse_embeddings = (
                box_embeddings
                if sparse_embeddings is None
                else torch.cat([sparse_embeddings, box_embeddings], dim=2)
            )

        if input_masks is None:
            grid_height, grid_width = self.image_embedding_size[0], self.image_embedding_size[1]
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch_size, -1, grid_height, grid_width
            )
        else:
            dense_embeddings = self.mask_embed(input_masks)

        if sparse_embeddings is None:
            sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)

        return sparse_embeddings, dense_embeddings
class SamVisionAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size):
        super().__init__()
        # Global attention (window_size == 0) spans the whole patch grid; otherwise a single window.
        if window_size == 0:
            grid_size = config.image_size // config.patch_size
            input_size = (grid_size, grid_size)
        else:
            input_size = (window_size, window_size)

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.scale = head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")

            # Relative positional embeddings, zero-initialized: one row per possible offset along each axis.
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Look up the relative positional embedding for every (query position, key position) pair along one axis.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions, indexed as (q_size, k_size, channel).
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)

        # Linearly resample the table to exactly `max_rel_dist` entries.
        table = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        table = F.interpolate(table, size=max_rel_dist, mode="linear")
        table = table.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        q_stride = max(k_size / q_size, 1.0)
        k_stride = max(q_size / k_size, 1.0)
        q_coords = torch.arange(q_size)[:, None] * q_stride
        k_coords = torch.arange(k_size)[None, :] * k_stride
        relative_coords = (q_coords - k_coords) + (k_size - 1) * k_stride

        return table[relative_coords.long()]

    def add_decomposed_rel_pos(
        self,
        attn: torch.Tensor,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        height_table = self.get_rel_pos(query_height, key_height, rel_pos_h)
        width_table = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = query.shape
        query_grid = query.reshape(batch_size, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", query_grid, height_table)
        rel_w = torch.einsum("bhwc,wkc->bhwk", query_grid, width_table)

        # Broadcast the height bias over key columns and the width bias over key rows.
        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        return attn.reshape(batch_size, query_height * query_width, key_height * key_width)

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        """Apply self-attention over a (batch_size, height, width, channel) feature map.

        Returns:
            A 2-tuple `(attn_output, attn_weights)`. `attn_output` has the same leading shape as the input;
            `attn_weights` are the softmax-normalized weights before dropout, with shape
            (batch_size * num_heads, height * width, height * width), or `None` when `output_attentions` is falsy.
        """
        batch_size, height, width, _ = hidden_states.shape
        num_tokens = height * width
        num_heads = self.num_attention_heads

        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = self.qkv(hidden_states)
        qkv = qkv.reshape(batch_size, num_tokens, 3, num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * num_heads, num_tokens, -1).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        # Softmax is computed in float32 and cast back to the query dtype.
        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = (attn_probs @ value).reshape(batch_size, num_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        return (attn_output, attn_weights if output_attentions else None)
class SamVisionLayer(nn.Module):
    """Pre-norm transformer layer of the vision encoder, with optional windowed attention.

    When `window_size > 0`, the feature map is split into non-overlapping windows before attention and
    reassembled afterwards; `window_size == 0` attends over the whole feature map.
    """

    def __init__(self, config, window_size):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = SamVisionAttention(config, window_size)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SamMLPBlock(config)
        self.window_size = window_size

    def window_partition(
        self, hidden_states: torch.Tensor, window_size: int
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Partition into non-overlapping windows, zero-padding the bottom/right edges if needed.

        Args:
            hidden_states (tensor):
                input tokens with [batch_size, height, width, channel].
            window_size (int):
                window size.

        Returns:
            windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
            (pad_height, pad_width): padded height and width before partition
        """
        batch_size, height, width, channel = hidden_states.shape

        # Amount needed to reach the next multiple of window_size (0 when already aligned).
        extra_rows = (window_size - height % window_size) % window_size
        extra_cols = (window_size - width % window_size) % window_size
        padded = F.pad(hidden_states, (0, 0, 0, extra_cols, 0, extra_rows))
        pad_height = height + extra_rows
        pad_width = width + extra_cols

        rows, cols = pad_height // window_size, pad_width // window_size
        tiled = padded.reshape(batch_size, rows, window_size, cols, window_size, channel)
        windows = tiled.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
        return windows, (pad_height, pad_width)

    def window_unpartition(
        self,
        windows: torch.Tensor,
        window_size: int,
        padding_shape: Tuple[int, int],
        original_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Window unpartition into original sequences and removing padding.

        Args:
            windows (tensor):
                input tokens with [batch_size * num_windows, window_size, window_size, channel].
            window_size (int):
                window size.
            padding_shape (Tuple):
                padded height and width (pad_height, pad_width).
            original_shape (Tuple):
                original height and width (height, width) before padding.

        Returns:
            hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
        """
        pad_height, pad_width = padding_shape
        height, width = original_shape

        windows_per_image = pad_height * pad_width // window_size // window_size
        batch_size = windows.shape[0] // windows_per_image

        rows, cols = pad_height // window_size, pad_width // window_size
        tiled = windows.reshape(batch_size, rows, cols, window_size, window_size, -1)
        merged = tiled.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)

        # Drop the rows/columns that were added as padding.
        return merged[:, :height, :width, :].contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """Run attention and MLP sub-blocks, each with a residual connection.

        Returns:
            `(hidden_states,)`, extended with the attention weights when `output_attentions` is truthy.
        """
        windowed = self.window_size > 0
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        if windowed:
            # Remember the unpadded size so the windows can be merged back afterwards.
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)

        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )

        if windowed:
            hidden_states = self.window_unpartition(
                hidden_states, self.window_size, padding_shape, (height, width)
            )

        hidden_states = residual + hidden_states
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class SamVisionNeck(nn.Module):
    """Projects the vision encoder output to `config.output_channels` channels.

    The neck is a bias-free 1x1 convolution followed by a bias-free 3x3 convolution, each followed by a
    channels-first layer norm.
    """

    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config
        out_channels = config.output_channels

        self.conv1 = nn.Conv2d(config.hidden_size, out_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SamLayerNorm(out_channels, data_format="channels_first")
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.layer_norm2 = SamLayerNorm(out_channels, data_format="channels_first")

    def forward(self, hidden_states):
        """Move the last dimension to position 1 (channels-first), then apply the conv/norm stages in order."""
        features = hidden_states.permute(0, 3, 1, 2)
        for stage in (self.conv1, self.layer_norm1, self.conv2, self.layer_norm2):
            features = stage(features)
        return features
class SamVisionEncoder(nn.Module):
    """Vision encoder: patch embedding, a stack of `SamVisionLayer`s and a convolutional neck.

    Layers whose index is listed in `config.global_attn_indexes` use global attention (`window_size=0`); all
    other layers use windowed attention with `config.window_size`.
    """

    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config
        self.image_size = config.image_size

        self.patch_embed = SamPatchEmbeddings(config)

        self.pos_embed = None
        if config.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            grid_size = config.image_size // config.patch_size
            self.pos_embed = nn.Parameter(torch.zeros(1, grid_size, grid_size, config.hidden_size))

        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            layer = SamVisionLayer(
                config,
                window_size=config.window_size if i not in config.global_attn_indexes else 0,
            )
            self.layers.append(layer)

        self.neck = SamVisionNeck(config)

        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        """Return the patch embedding module."""
        return self.patch_embed

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SamVisionEncoderOutput]:
        """Encode `pixel_values` into image embeddings.

        Args:
            pixel_values (`torch.FloatTensor`):
                Input images; required.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention weights of every layer. Defaults to the config value.
            output_hidden_states (`bool`, *optional*):
                Whether to also return hidden states. Only the outputs of global-attention layers
                (`window_size == 0`) are collected, followed by the final pre-neck hidden state. Defaults to the
                config value.
            return_dict (`bool`, *optional*):
                Whether to return a `SamVisionEncoderOutput` instead of a plain tuple. Defaults to the config value.

        Returns:
            `SamVisionEncoderOutput`, or a tuple `(last_hidden_state, [hidden_states], [attentions])`.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.patch_embed(pixel_values)
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for layer_module in self.layers:
            if self.gradient_checkpointing and self.training:
                # `output_attentions` must be forwarded here too; otherwise the layer returns a 1-tuple and
                # reading `layer_outputs[1]` below fails when attentions are requested.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            # Only global-attention layers contribute intermediate hidden states.
            if output_hidden_states and layer_module.window_size == 0:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.neck(hidden_states)

        if not return_dict:
            outputs = (hidden_states,)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

        return SamVisionEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class SamHQPreTrainedModel(PreTrainedModel):
    """Base class binding the SAM-HQ models to `SamHQConfig` and defining their weight initialization."""

    config_class = SamHQConfig
    base_model_prefix = "sam_hq"
    main_input_name = "pixel_values"
    _no_split_modules = ["SamVisionAttention"]

    def _init_weights(self, module):
        """Initialize `module` in place.

        Linear, convolution and transposed-convolution weights, as well as embedding weights, are drawn from
        `N(0, config.initializer_range)`. Biases are zeroed, and so is the embedding row at `padding_idx` when
        one is set. Other module types are left untouched.
        """
        std = self.config.initializer_range
        has_dense_weight = isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d))

        if not has_dense_weight and not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if has_dense_weight:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
# Class-level docstring prefix; attached to the model class via `add_start_docstrings`.
SAM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`SamHQConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation prepended to `forward` via `add_start_docstrings_to_model_forward`.
SAM_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
details.
input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points, 2)`):
Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
better results. The points can be obtained by passing a list of list of list to the processor that will
create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
coordinates of the point. If a different number of points is passed either for each image, or for each
mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
computation of the embedding will be skipped for these points using the labels.
input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
official implementation, there are 3 types of labels
- `1`: the point is a point that contains the object of interest
- `0`: the point is a point that does not contain the object of interest
- `-1`: the point corresponds to the background
We added the label:
- `-10`: the point is a padding point, thus should be ignored by the prompt encoder
The padding labels should be automatically done by the processor.
input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
In the order (`x1`, `y1`, `x2`, `y2`):
- `x1`: the x coordinate of the top left point of the input box
- `y1`: the y coordinate of the top left point of the input box
- `x2`: the x coordinate of the bottom right point of the input box
- `y2`: the y coordinate of the bottom right point of the input box
input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
multimask_output (`bool`, *optional*):
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
"best" mask, by specifying `multimask_output=False`.
hq_token_only (`bool`, *optional*, defaults to `True`):
Whether to return only the HQ mask prediction. If `False`, the SAM mask prediction and the HQ mask
prediction are summed.
attention_similarity (`torch.FloatTensor`, *optional*):
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
target_embedding (`torch.FloatTensor`, *optional*):
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
    " optional 2D location and bounding boxes.",
    SAM_START_DOCSTRING,
)
class SamHQModel(SamHQPreTrainedModel):
    # The prompt encoder reuses the positional embedding module owned by this model.
    _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]

    def __init__(self, config):
        super().__init__(config)
        self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)

        self.vision_encoder = SamVisionEncoder(config.vision_config)
        self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding)

        # Default the decoder's view of the encoder width to the vision encoder's hidden size.
        if "vision_encoder_dim" not in config.mask_decoder_config.to_dict():
            config.mask_decoder_config.vision_encoder_dim = config.vision_config.hidden_size
        self.mask_decoder = SamMaskDecoderHQ(config.mask_decoder_config)

        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the vision encoder."""
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self):
        """Return the positional encoding of every cell of the image embedding grid.

        Cell centers are placed at `(index + 0.5) / size` along each axis and encoded with the shared positional
        embedding. The result has shape `(1, channel, size, size)`, where `size` is
        `config.prompt_encoder_config.image_embedding_size`.
        """
        size = self.config.prompt_encoder_config.image_embedding_size
        target_device = self.shared_image_embedding.positional_embedding.device
        target_dtype = self.shared_image_embedding.positional_embedding.dtype
        grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / size
        x_embed = x_embed / size

        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # channel x height x width

    @torch.no_grad()
    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        vision_output = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Only the first element (the image embeddings) is returned.
        image_embeddings = vision_output[0]
        return image_embeddings

    @torch.no_grad()
    def get_prompt_embeddings(
        self,
        input_points: Optional[torch.FloatTensor] = None,
        input_labels: Optional[torch.LongTensor] = None,
        input_boxes: Optional[torch.FloatTensor] = None,
        input_masks: Optional[torch.LongTensor] = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. users can also pass manually the input boxes.
            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_points: Optional[torch.FloatTensor] = None,
        input_labels: Optional[torch.LongTensor] = None,
        input_boxes: Optional[torch.FloatTensor] = None,
        input_masks: Optional[torch.LongTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        multimask_output: bool = False,
        hq_token_only: bool = True,
        attention_similarity: Optional[torch.FloatTensor] = None,
        target_embedding: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SamImageSegmentationOutput]:
        r"""
        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")

        >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

        >>> # Get segmentation mask
        >>> outputs = model(**inputs)

        >>> # Postprocess masks
        >>> masks = processor.post_process_masks(
        ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
        ... )
        ```
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

        if pixel_values is not None and image_embeddings is not None:
            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")

        if pixel_values is None:
            # The mask decoder is fed intermediate vision encoder hidden states, which precomputed image
            # embeddings do not carry. Fail early with a clear message instead of a `TypeError` further down.
            raise ValueError(
                "pixel_values must be provided: the HQ mask decoder needs the intermediate hidden states of the "
                "vision encoder, which cannot be recovered from precomputed image_embeddings."
            )

        if input_points is not None and len(input_points.shape) != 4:
            raise ValueError(
                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, "
                f"`nb_points_per_image`, `2`. got {input_points.shape}."
            )

        if input_boxes is not None and len(input_boxes.shape) != 3:
            raise ValueError(
                "The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`. "
                f"got {input_boxes.shape}."
            )

        if input_points is not None and input_boxes is not None:
            point_batch_size = input_points.shape[1]
            box_batch_size = input_boxes.shape[1]
            if point_batch_size != box_batch_size:
                raise ValueError(
                    "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
                        point_batch_size, box_batch_size
                    )
                )

        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # repeat with batch size
        batch_size = pixel_values.shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        # Hidden states are always requested from the encoder because the mask decoder consumes them; they are
        # only exposed in the outputs when the caller asked for them.
        vision_outputs = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        image_embeddings = vision_outputs[0]
        vision_hidden_states = vision_outputs[1]
        vision_attentions = vision_outputs[-1] if output_attentions else None

        # Points without labels are all treated as label 1.
        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(
                input_points[:, :, :, 0], dtype=torch.int, device=input_points.device
            )

        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
            raise ValueError(
                "The batch size of the image embeddings and the input points must be the same. "
                f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively."
                " if you want to pass multiple points for the same image, make sure that you passed "
                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and "
                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)"
            )

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )

        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            intermediate_vision_embeddings=vision_hidden_states[1:],
            hq_token_only=hq_token_only,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )

        if not return_dict:
            output = (iou_predictions, low_res_masks)
            if output_hidden_states:
                output = output + (vision_hidden_states,)

            if output_attentions:
                output = output + (vision_attentions, mask_decoder_attentions)
            return output

        return SamImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states if output_hidden_states else None,
            vision_attentions=vision_attentions,
            mask_decoder_attentions=mask_decoder_attentions,
        )
|