Spaces:
Runtime error
Runtime error
import os | |
def prelude(
    fairseq_code: str = (
        "/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py"
    ),
) -> None:
    """Prepare the process environment and hot-patch fairseq for JIT scripting.

    Two things happen, in order:

    1. The ``PYTORCH_JIT`` environment variable is set to ``"0v"``.
    2. If *fairseq_code* exists, every line containing the untyped signature
       of ``expand_2d_or_3d_tensor`` is replaced by a fully annotated
       signature (``x: Tensor`` argument, ``-> Tensor`` return).  All other
       lines are kept as they are.

    The file is only rewritten when at least one line actually changes, so
    repeated calls are no-ops and an already patched (or read-only) file is
    never truncated.

    Args:
        fairseq_code: Path of the fairseq ``model_utils.py`` to patch.
            Defaults to the location inside the Python 3.10 site-packages
            directory under ``/usr/local``.

    Raises:
        OSError: If the file exists but cannot be read, or needs patching
            and cannot be written.
    """
    # NOTE(review): "0v" appears to be the verbose form of "0" understood by
    # torch's PYTORCH_JIT parsing (disable JIT and print a notice) — confirm
    # against the installed torch version before changing it.
    os.environ["PYTORCH_JIT"] = "0v"

    old_signature = (
        "def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):"
    )
    # Assumes `Tensor` is already imported in the target module — the patch
    # only rewrites the signature line and adds no import.
    new_line = (
        "def expand_2d_or_3d_tensor(x: Tensor, trg_dim: int, padding_idx: int) -> Tensor:\n"
    )

    if not os.path.exists(fairseq_code):
        return

    with open(fairseq_code, "r", encoding="utf-8") as f:
        lines = f.readlines()

    patched = [new_line if old_signature in line else line for line in lines]

    # Nothing matched (not this fairseq version, or already patched): leave
    # the file untouched instead of truncating and rewriting it.
    if patched == lines:
        return

    with open(fairseq_code, "w", encoding="utf-8") as f:
        f.writelines(patched)