from torch import nn

# Registry of fully-connected (linear) layer implementations, keyed by a short
# backend name. Callers look up a class here instead of hard-coding nn.Linear.
# The plain PyTorch implementation is always available.
FC_CLASS_REGISTRY = {'torch': nn.Linear}

# Optionally register NVIDIA Transformer Engine's Linear under 'te'. This is
# best-effort: if transformer_engine cannot be imported, the registry simply
# contains only the 'torch' entry. The import may fail for reasons other than
# the package being absent (presumably e.g. native/CUDA library loading), so
# any Exception is tolerated. BaseException subclasses such as
# KeyboardInterrupt and SystemExit are deliberately NOT swallowed.
try:
    import transformer_engine.pytorch as te
    FC_CLASS_REGISTRY['te'] = te.Linear
except Exception:
    pass