fix conflict
Browse files- gcvit/models/gcvit.py +1 -24
gcvit/models/gcvit.py
CHANGED
@@ -2,25 +2,12 @@ import numpy as np
|
|
2 |
import tensorflow as tf
|
3 |
|
4 |
from ..layers import Stem, GCViTLevel, Identity
|
5 |
-
from ..layers import Stem, GCViTLevel, Identity
|
6 |
|
7 |
|
8 |
|
9 |
BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
|
10 |
TAG = 'v1.1.1'
|
11 |
NAME2CONFIG = {
|
12 |
-
'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
|
13 |
-
'dim': 64,
|
14 |
-
'depths': (2, 2, 6, 2),
|
15 |
-
'num_heads': (2, 4, 8, 16),
|
16 |
-
'mlp_ratio': 3.,
|
17 |
-
'path_drop': 0.2},
|
18 |
-
'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
|
19 |
-
'dim': 64,
|
20 |
-
'depths': (3, 4, 6, 5),
|
21 |
-
'num_heads': (2, 4, 8, 16),
|
22 |
-
'mlp_ratio': 3.,
|
23 |
-
'path_drop': 0.2},
|
24 |
'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
|
25 |
'dim': 64,
|
26 |
'depths': (2, 2, 6, 2),
|
@@ -94,7 +81,6 @@ class GCViT(tf.keras.Model):
|
|
94 |
self.levels = []
|
95 |
for i in range(len(depths)):
|
96 |
path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
|
97 |
-
level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
|
98 |
level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
|
99 |
downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
100 |
drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
|
@@ -110,17 +96,14 @@ class GCViT(tf.keras.Model):
|
|
110 |
else:
|
111 |
raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
|
112 |
self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
|
113 |
-
self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
|
114 |
|
115 |
-
|
116 |
def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
|
117 |
self.num_classes = num_classes
|
118 |
if global_pool is not None:
|
119 |
self.global_pool = global_pool
|
120 |
self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
|
121 |
super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
|
122 |
-
self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
|
123 |
-
super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
|
124 |
|
125 |
def forward_features(self, inputs):
|
126 |
x = self.patch_embed(inputs)
|
@@ -137,7 +120,6 @@ class GCViT(tf.keras.Model):
|
|
137 |
x = self.pool(x)
|
138 |
if not pre_logits:
|
139 |
x = self.head(x)
|
140 |
-
x = self.head(x)
|
141 |
return x
|
142 |
|
143 |
def call(self, inputs, **kwargs):
|
@@ -153,8 +135,6 @@ class GCViT(tf.keras.Model):
|
|
153 |
def summary(self, input_shape=(224, 224, 3)):
|
154 |
return self.build_graph(input_shape).summary()
|
155 |
|
156 |
-
def summary(self, input_shape=(224, 224, 3)):
|
157 |
-
return self.build_graph(input_shape).summary()
|
158 |
|
159 |
# load standard models
|
160 |
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
@@ -179,7 +159,6 @@ def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
|
|
179 |
model.load_weights(ckpt_path)
|
180 |
return model
|
181 |
|
182 |
-
def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
183 |
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
184 |
name = 'gcvit_xxtiny'
|
185 |
config = NAME2CONFIG[name]
|
@@ -215,7 +194,6 @@ def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **k
|
|
215 |
model.load_weights(ckpt_path)
|
216 |
return model
|
217 |
|
218 |
-
def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
219 |
def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
220 |
name = 'gcvit_small'
|
221 |
config = NAME2CONFIG[name]
|
@@ -229,7 +207,6 @@ def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **
|
|
229 |
model.load_weights(ckpt_path)
|
230 |
return model
|
231 |
|
232 |
-
def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
233 |
def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
234 |
name = 'gcvit_base'
|
235 |
config = NAME2CONFIG[name]
|
|
|
2 |
import tensorflow as tf
|
3 |
|
4 |
from ..layers import Stem, GCViTLevel, Identity
|
|
|
5 |
|
6 |
|
7 |
|
8 |
BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
|
9 |
TAG = 'v1.1.1'
|
10 |
NAME2CONFIG = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
|
12 |
'dim': 64,
|
13 |
'depths': (2, 2, 6, 2),
|
|
|
81 |
self.levels = []
|
82 |
for i in range(len(depths)):
|
83 |
path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
|
|
|
84 |
level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
|
85 |
downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
86 |
drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
|
|
|
96 |
else:
|
97 |
raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
|
98 |
self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
|
|
|
99 |
|
100 |
+
|
101 |
def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
|
102 |
self.num_classes = num_classes
|
103 |
if global_pool is not None:
|
104 |
self.global_pool = global_pool
|
105 |
self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
|
106 |
super().build((1, 224, 224, in_channels)) # for head we only need info from the input channel
|
|
|
|
|
107 |
|
108 |
def forward_features(self, inputs):
|
109 |
x = self.patch_embed(inputs)
|
|
|
120 |
x = self.pool(x)
|
121 |
if not pre_logits:
|
122 |
x = self.head(x)
|
|
|
123 |
return x
|
124 |
|
125 |
def call(self, inputs, **kwargs):
|
|
|
135 |
def summary(self, input_shape=(224, 224, 3)):
|
136 |
return self.build_graph(input_shape).summary()
|
137 |
|
|
|
|
|
138 |
|
139 |
# load standard models
|
140 |
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
|
|
159 |
model.load_weights(ckpt_path)
|
160 |
return model
|
161 |
|
|
|
162 |
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
163 |
name = 'gcvit_xxtiny'
|
164 |
config = NAME2CONFIG[name]
|
|
|
194 |
model.load_weights(ckpt_path)
|
195 |
return model
|
196 |
|
|
|
197 |
def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
198 |
name = 'gcvit_small'
|
199 |
config = NAME2CONFIG[name]
|
|
|
207 |
model.load_weights(ckpt_path)
|
208 |
return model
|
209 |
|
|
|
210 |
def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
|
211 |
name = 'gcvit_base'
|
212 |
config = NAME2CONFIG[name]
|