from math import ceil
from re import fullmatch, match

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from model import Model


class Data:
    """Container for input and output data"""

    # The 20 canonical amino acids, in the order used both for generating
    # substitutions and for the rows of the TMS heatmap.
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    # Initialise empty model as static class member for efficiency
    model = Model()

    def parse_seq(self, src: str):
        """Parse input sequence.

        All whitespace (spaces, newlines, carriage returns from pasted input) is
        removed and the result is upper-cased before validation.

        Raises:
            RuntimeError: if any character is not in the model alphabet.
        """
        self.seq = ''.join(src.split()).upper()
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def parse_sub(self, trg: str):
        """Parse input substitutions and set ``self.mode``, ``self.sub``, ``self.resi``.

        Running modes:
          * ``MUT``: either a single token of the same length as the sequence
            (a target sequence; every differing position becomes a
            substitution), or whitespace-separated ``X#Y`` substitutions.
          * ``DMS``: whitespace-separated residue numbers (1-based); all 19
            substitutions at each listed residue.
          * ``TMS``: anything else, including empty input; all 19
            substitutions at every residue of the sequence.

        ``self.sub`` ends up as a one-column DataFrame (column ``'0'``) of
        substitution strings such as ``"A12C"``.

        Raises:
            RuntimeError: if a residue number lies outside the sequence, or an
                ``X#Y`` substitution's source residue does not match the sequence.
        """
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()
        self.resi = list()

        tokens = self.trg.split()
        # Identify running mode
        if len(tokens) == 1 and len(tokens[0]) == len(self.seq) and all(match(r'\w+', x) for x in self.trg):
            # If single string of same length as sequence, seq vs seq mode
            self._parse_target_sequence()
        else:
            self.trg = tokens
            # `tokens and ...` keeps empty input out of DMS/MUT: `all()` over an
            # empty list is True and would otherwise yield zero substitutions.
            if tokens and all(fullmatch(r'\d+', x) for x in tokens):
                # If all strings are numbers, deep mutational scanning mode
                self._parse_residue_list()
            elif tokens and all(fullmatch(r'[A-Z]\d+[A-Z]', x) for x in tokens):
                # If all strings are of the form X#Y, single substitution mode
                self._parse_substitution_list()
            else:
                self._parse_whole_sequence()
        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def _check_residue(self, resi: int) -> None:
        """Raise RuntimeError unless ``resi`` is a valid 1-based position in ``self.seq``."""
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Residue {resi} is outside the sequence (1-{len(self.seq)})")

    def _parse_target_sequence(self) -> None:
        """MUT mode: one substitution per position where ``self.trg`` differs from ``self.seq``."""
        self.mode = 'MUT'
        for resi, (src, dst) in enumerate(zip(self.seq, self.trg), 1):
            if src != dst:
                self.sub.append(f"{src}{resi}{dst}")
                self.resi.append(resi)

    def _parse_residue_list(self) -> None:
        """DMS mode: all 19 substitutions at each residue number in ``self.trg``."""
        self.mode = 'DMS'
        for resi in map(int, self.trg):
            self._check_residue(resi)
            src = self.seq[resi - 1]
            for dst in self.AMINO_ACIDS.replace(src, ''):
                self.sub.append(f"{src}{resi}{dst}")
                self.resi.append(resi)

    def _parse_substitution_list(self) -> None:
        """MUT mode: validate explicit ``X#Y`` substitutions against ``self.seq``."""
        self.mode = 'MUT'
        self.sub = self.trg
        self.resi = [int(x[1:-1]) for x in self.trg]
        for sub, resi in zip(self.trg, self.resi):
            self._check_residue(resi)
            wild_type = self.seq[resi - 1]
            if wild_type != sub[0]:
                raise RuntimeError(f"Unrecognised input substitution {wild_type}{resi} /= {sub[0]}{resi}")

    def _parse_whole_sequence(self) -> None:
        """TMS mode: all 19 substitutions at every residue of ``self.seq``."""
        self.mode = 'TMS'
        for resi, src in enumerate(self.seq, 1):
            for dst in self.AMINO_ACIDS.replace(src, ''):
                self.sub.append(f"{src}{resi}{dst}")
                self.resi.append(resi)

    def __init__(self, src: str, trg: str, model_name: str = 'facebook/esm2_t33_650M_UR50D',
                 scoring_strategy: str = 'masked-marginals', out_file=None):
        """Initialise data.

        Args:
            src: input protein sequence (see ``parse_seq``).
            trg: substitution specification (see ``parse_sub``).
            model_name: model to use; a new ``Model`` is loaded for this
                instance only when it differs from the shared class-level model.
            scoring_strategy: stored on the instance (presumably read by
                ``Model.run_model``; confirm in model.py).
            out_file: optional output destination. If it has a ``name``
                attribute (e.g. an open/temporary file), that name is used as
                the output path; otherwise it is used as given.
        """
        # Always record the requested name: it labels the score column below,
        # including when the shared model already matches it.
        self.model_name = model_name
        # if model has changed, load new model (stored on this instance only)
        if self.model.model_name != model_name:
            self.model = Model(model_name)
        self.parse_seq(src)
        self.offset = 0
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.token_probs = None
        # Selecting ['0', model_name] from the one-column substitution frame adds
        # an all-NaN score column, presumably filled in by Model.run_model.
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_str = None
        self.out_buffer = getattr(out_file, 'name', out_file)

    def parse_output(self) -> None:
        """Format output data for visualisation.

        TMS results become a heatmap (SVG text in ``self.out_str`` when an
        output path is set). DMS/MUT results are sorted, optionally written as
        CSV to ``self.out_buffer``, and rendered as a coloured HTML table in
        ``self.out_str``.

        Raises:
            RuntimeError: if ``self.mode`` is not one of TMS/DMS/MUT.
        """
        if self.mode == 'TMS':
            self.process_tms_mode()
        else:
            if self.mode == 'DMS':
                self.sort_by_residue_and_score()
            elif self.mode == 'MUT':
                self.sort_by_score()
            else:
                raise RuntimeError(f"Unrecognised mode {self.mode}")
            if self.out_buffer:
                self.out.round(2).to_csv(self.out_buffer, index=False, header=False)
            self.out_str = (self.out.style
                            .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                            .hide(axis=0)
                            .hide(axis=1)
                            .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                            .to_html(justify='center'))

    def sort_by_score(self):
        """Sort substitutions by descending model score."""
        self.out = self.out.sort_values(self.model_name, ascending=False)

    def sort_by_residue_and_score(self):
        """Lay DMS results out side by side: one (substitution, score) column pair per residue.

        Within each residue, substitutions are sorted by descending score.
        """
        self.out = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                    .sort_values(['resi', self.model_name], ascending=[True, False])
                    .groupby(['resi'])
                    .head(19)
                    .drop(['resi'], axis=1))
        # Each residue contributes exactly 19 rows, so fixed-size slices
        # correspond to individual residues.
        nblocks = self.out.shape[0] // 19
        self.out = (pd.concat([self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(nblocks)],
                              axis=1)
                    .set_axis(range(nblocks * 2), axis='columns'))

    def process_tms_mode(self):
        """Reshape TMS results into a 20 x L matrix, normalise it and plot it."""
        self.out = self.assign_resi_and_group()
        self.out = self.concat_and_set_axis()
        scale = self.out.abs().max().max()
        if scale:  # skip when every score is 0, which would otherwise give 0/0 = NaN
            self.out /= scale
        divs = self.calculate_divs()
        ncols = min(divs, key=lambda x: abs(x-60))
        nrows = ceil(self.out.shape[1]/ncols)
        ncols = self.adjust_ncols(ncols, nrows)
        self.plot_heatmap(ncols, nrows)

    def assign_resi_and_group(self):
        """Return ``self.out`` with a residue-number column, keeping at most 19 rows per residue."""
        return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                .groupby(['resi'])
                .head(19))

    def concat_and_set_axis(self):
        """Build the (amino acid x residue) score matrix.

        Rows are the 20 amino acids. Columns are labelled like ``'A1'``
        (wild type + position). Wild-type cells are 0.
        """
        return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
                            .pipe(self.create_dataframe)
                            .sort_values(['0'], ascending=[True])
                            .drop(['resi', '0'], axis=1)
                            .set_axis(list(self.AMINO_ACIDS))
                            .astype(float)
                            ) for x in range(self.out.shape[0]//19)],
                          axis=1)
                .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))

    def create_dataframe(self, df):
        """Prepend a wild-type-to-itself row (e.g. ``'A1A'``) with zero values to one residue's rows."""
        return pd.concat([pd.Series([df.iloc[0, 0][:-1]+df.iloc[0, 0][0], 0, 0], index=df.columns).to_frame().T, df],
                         axis=0, ignore_index=True)

    def calculate_divs(self):
        """Return divisors of the column count between 30 and 60, or ``[60]`` if none."""
        length = self.out.shape[1]
        return [x for x in range(30, min(length, 60) + 1) if length % x == 0] or [60]

    def adjust_ncols(self, ncols, nrows):
        """Shrink ``ncols`` (not below 46) while ``nrows`` rows still cover every column.

        This balances the row widths, so the last row is not much shorter
        than the others. Returns ``ncols`` unchanged when it already fits
        exactly or is already at or below the lower bound.
        """
        total = self.out.shape[1]
        while ncols > 46 and (ncols - 1) * nrows >= total:
            ncols -= 1
        return ncols

    def plot_heatmap(self, ncols, nrows):
        """Plot the TMS matrix and, if an output path is set, save it as SVG and load the text.

        The figure is closed after saving so repeated calls do not accumulate
        open figures. Without an output path it stays open for interactive display.
        """
        if nrows < 2:
            self.plot_single_heatmap()
        else:
            self.plot_multiple_heatmaps(ncols, nrows)
        if self.out_buffer:
            fig = plt.gcf()
            fig.savefig(self.out_buffer, format='svg')
            plt.close(fig)
            with open(self.out_buffer, 'r', encoding='utf-8') as f:
                self.out_str = f.read()

    @staticmethod
    def _draw_heatmap(data, ax=None):
        """Draw ``data`` as a square, zero-centred RdBu heatmap on ``ax`` (current axes if None).

        Cells equal to 0 (the inserted wild-type entries) are annotated with '·'.
        """
        sns.heatmap(data,
                    ax=ax,
                    cmap='RdBu',
                    cbar=False,
                    square=True,
                    xticklabels=1,
                    yticklabels=1,
                    center=0,
                    annot=data.map(lambda x: ' ' if x != 0 else '·'),
                    fmt='s',
                    annot_kws={'size': 'xx-large'})

    def plot_single_heatmap(self):
        """Plot the whole matrix in one heatmap."""
        fig = plt.figure(figsize=(12, 6))
        self._draw_heatmap(self.out)
        fig.tight_layout()

    def plot_multiple_heatmaps(self, ncols, nrows):
        """Plot the matrix as ``nrows`` stacked heatmaps of up to ``ncols`` residues each."""
        fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
        for i in range(nrows):
            self._draw_heatmap(self.out.iloc[:, i*ncols:(i+1)*ncols], ax=ax[i])
            ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
            ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
        fig.tight_layout()

    def calculate(self):
        """Run model and parse output; returns ``self`` for chaining."""
        self.model.run_model(self)
        self.parse_output()
        return self

    def __str__(self):
        """Return output data in DataFrame format."""
        return str(self.out)

    def __repr__(self):
        """Return output data in HTML/SVG format, falling back to the DataFrame text.

        The fallback covers calls made before ``calculate``, and TMS runs
        without an output file, where ``out_str`` stays None.
        """
        return self.out_str if self.out_str is not None else str(self.out)