Update celle/celle.py
Browse files- celle/celle.py +4 -2
celle/celle.py
CHANGED
@@ -262,8 +262,10 @@ class ModelExtender(nn.Module):
|
|
262 |
|
263 |
# Set the number of output features and initialize the scaling layer
|
264 |
self.out_features = out_features
|
265 |
-
|
266 |
-
|
|
|
|
|
267 |
# Determine whether to freeze the model's parameters
|
268 |
self.fixed_embedding = fixed_embedding
|
269 |
if self.fixed_embedding:
|
|
|
262 |
|
263 |
# Set the number of output features and initialize the scaling layer
|
264 |
self.out_features = out_features
|
265 |
+
if self.in_features != self.out_features:
|
266 |
+
self.scale_layer = nn.Linear(self.in_features, self.out_features)
|
267 |
+
else:
|
268 |
+
self.scale_layer = nn.Identity()
|
269 |
# Determine whether to freeze the model's parameters
|
270 |
self.fixed_embedding = fixed_embedding
|
271 |
if self.fixed_embedding:
|