import os

# Default location of fairseq's model_utils.py in the Python 3.10
# site-packages layout this script targets.
FAIRSEQ_CODE = (
    "/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py"
)

# Signature as shipped by fairseq, and the annotated replacement we patch in.
_OLD_SIGNATURE = "def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):"
_NEW_SIGNATURE = (
    "def expand_2d_or_3d_tensor(x: Tensor, trg_dim: int, padding_idx: int) -> Tensor:"
)


def prelude(fairseq_code: str = FAIRSEQ_CODE) -> None:
    """Set PYTORCH_JIT and patch fairseq's ``expand_2d_or_3d_tensor`` signature.

    Sets ``PYTORCH_JIT=0v`` in the process environment. Then, if
    *fairseq_code* exists, rewrites the unannotated
    ``expand_2d_or_3d_tensor(x, ...)`` signature in it to carry explicit
    ``Tensor`` annotations (a "patch for jit script", per the original
    author). The file is only written when something actually changes, so
    repeated calls, or calls against an already-patched or read-only install,
    are no-ops.

    Args:
        fairseq_code: Path to fairseq's ``models/model_utils.py``. Defaults
            to the hard-coded Python 3.10 site-packages location.
    """
    # NOTE(review): "0v" is presumably torch's "disable JIT, verbosely" value;
    # confirm against torch/jit/_state.py for the installed torch version.
    os.environ["PYTORCH_JIT"] = "0v"

    try:
        # newline="" keeps the file's original line endings intact on rewrite.
        with open(fairseq_code, encoding="utf-8", newline="") as f:
            lines = f.readlines()
    except FileNotFoundError:
        return  # fairseq not installed at this path; nothing to patch

    # Replace only the signature text so the rest of the line (and its
    # line ending) survives.
    # NOTE(review): assumes model_utils.py already has `Tensor` in scope
    # (e.g. `from torch import Tensor`) — confirm for the installed fairseq.
    patched = [line.replace(_OLD_SIGNATURE, _NEW_SIGNATURE) for line in lines]
    if patched == lines:
        return  # already patched or signature absent: don't touch the file

    with open(fairseq_code, "w", encoding="utf-8", newline="") as f:
        f.writelines(patched)