#!/usr/bin/env python3
"""Delete ``# Copied from ...`` helper functions from Python source files in place.

Usage::

    python3 this_script.py FILE [FILE ...]

A block is removed when a line equal (after stripping whitespace) to
``MATCH_PATTERN_1`` is immediately followed by a line equal to
``MATCH_PATTERN_2``. Removal starts at the marker line and continues until a
blank line whose *next* line (stripped) equals ``END_MATCH_PATTERN_2``. With
the default empty end pattern, that means the first pair of consecutive blank
lines. That blank line and the lines after it are kept.

NOTE: the end condition assumes the removed function is followed by two blank
lines, as between top-level definitions. For a nested definition separated by
a single blank line, removal would run on into the following code, so review
the diff before committing. The blank lines around a removed block are all
kept, so run the project formatter afterwards.
"""
import sys

# Marker comment and the ``def`` line that must immediately follow it.
# Earlier runs of this script targeted ``_make_causal_mask`` and
# ``_prepare_decoder_attention_mask``. Only one pattern pair is active at a time.
MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart.prepare_4d_attention_mask"
MATCH_PATTERN_2 = "def prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):"
# Removal stops on a blank line whose next line (stripped) equals this value.
END_MATCH_PATTERN_2 = ""


def strip_copied_blocks(
    lines,
    start_comment=MATCH_PATTERN_1,
    start_def=MATCH_PATTERN_2,
    end_next=END_MATCH_PATTERN_2,
):
    """Return ``(kept_lines, removed_at)`` for *lines* with matching blocks dropped.

    Args:
        lines: Source lines, e.g. from ``file.readlines()``. Line endings are
            preserved in the output.
        start_comment: Stripped text of the line that opens a block.
        start_def: Stripped text that must appear on the very next line.
        end_next: Removal stops on a blank line whose next line, stripped,
            equals this value. That blank line is kept.

    Returns:
        A tuple of the surviving lines (a new list; *lines* is not modified)
        and the 1-based line numbers where removed blocks started.
    """
    kept = []
    removed_at = []
    deleting = False
    last = len(lines) - 1
    for i, line in enumerate(lines):
        stripped = line.strip()
        # None on the final line, so neither two-line condition can match there.
        next_stripped = lines[i + 1].strip() if i < last else None
        if stripped == start_comment and next_stripped == start_def:
            deleting = True
            removed_at.append(i + 1)
        elif stripped == "" and next_stripped == end_next:
            deleting = False
        if not deleting:
            kept.append(line)
    return kept, removed_at


def main(argv=None):
    """Strip matching blocks from each file named in *argv* (default: ``sys.argv[1:]``)."""
    filenames = sys.argv[1:] if argv is None else argv
    for filename in filenames:
        # Explicit UTF-8 so the result does not depend on the platform locale.
        # newline="" keeps each file's original line endings on rewrite.
        with open(filename, encoding="utf-8", newline="") as f:
            lines = f.readlines()
        kept, removed_at = strip_copied_blocks(lines)
        if not removed_at:
            continue  # nothing matched: leave the file (and its mtime) untouched
        for lineno in removed_at:
            print(f"{filename}:{lineno}: removed copied block")
        with open(filename, "w", encoding="utf-8", newline="") as f:
            f.writelines(kept)


if __name__ == "__main__":
    main()