Update modeling_t5.py
Browse files- modeling_t5.py +12 -12
modeling_t5.py
CHANGED
@@ -274,18 +274,18 @@ class T5LayerNorm(nn.Module):
|
|
274 |
return self.weight * hidden_states
|
275 |
|
276 |
|
277 |
-
try:
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
except ImportError:
|
284 |
-
|
285 |
-
|
286 |
-
except Exception:
|
287 |
-
|
288 |
-
|
289 |
|
290 |
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
|
291 |
|
|
|
274 |
return self.weight * hidden_states
|
275 |
|
276 |
|
277 |
+
# try:
|
278 |
+
# from apex.normalization import FusedRMSNorm
|
279 |
+
|
280 |
+
# T5LayerNorm = FusedRMSNorm # noqa
|
281 |
+
|
282 |
+
# logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
|
283 |
+
# except ImportError:
|
284 |
+
# # using the normal T5LayerNorm
|
285 |
+
# pass
|
286 |
+
# except Exception:
|
287 |
+
# logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
|
288 |
+
# pass
|
289 |
|
290 |
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
|
291 |
|