File size: 8,798 Bytes
fba8f5e
 
 
 
b212cb1
fba8f5e
 
 
b212cb1
fba8f5e
 
 
b212cb1
 
 
fba8f5e
b212cb1
 
fba8f5e
 
 
 
b212cb1
 
fba8f5e
 
b212cb1
 
 
fba8f5e
b212cb1
fba8f5e
 
 
 
 
b212cb1
 
fba8f5e
b212cb1
 
fba8f5e
 
b212cb1
 
 
fba8f5e
b212cb1
fba8f5e
 
 
b212cb1
 
fba8f5e
 
 
 
b212cb1
fba8f5e
 
 
 
 
 
b212cb1
 
fba8f5e
b212cb1
 
 
 
 
 
fba8f5e
b212cb1
 
fba8f5e
b212cb1
fba8f5e
 
b212cb1
fba8f5e
b212cb1
fba8f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b212cb1
 
 
fba8f5e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from math import ceil
from re import match
import seaborn as sns

from model import Model


import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from model import Model

class Data:
    """Container for input and output data.

    An instance holds the parsed input sequence (``seq``), the list of
    substitutions to score (``sub``), the running mode (``mode``) and, after
    :meth:`calculate`, the formatted output (``out`` and ``out_str``).

    Running modes set by :meth:`parse_sub`:

    * ``'MUT'`` - explicit substitutions, given either as a full target
      sequence of the same length as the input or as ``X#Y`` tokens.
    * ``'DMS'`` - every substitution at the listed positions.
    * ``'TMS'`` - every substitution at every position (catch-all mode).
    """
    # Initialise empty model as static class member for efficiency: the most
    # recently loaded model is kept here so that consecutive instances asking
    # for the same model name do not load it again.
    model = Model()

    # One-letter residue codes used to enumerate substitutions.
    _AA = "ACDEFGHIKLMNPQRSTVWY"
    # Substitutions generated per position: every code except the original.
    _N_SUB = len(_AA) - 1

    def parse_seq(self, src: str):
        """Parse and validate the input sequence.

        All whitespace (including line breaks and carriage returns) is
        removed and the result is upper-cased before being stored in
        ``self.seq``.

        Raises:
            RuntimeError: if any character is not in ``self.model.alphabet``.
        """
        self.seq = ''.join(src.split()).upper()
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def _check_resi(self, resi: int) -> None:
        """Raise RuntimeError unless ``resi`` is a valid 1-based position."""
        # Without this check position 0 would silently wrap around to the
        # last residue (seq[-1]) and large positions would raise IndexError.
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Residue position {resi} is outside the sequence (1-{len(self.seq)})")

    def parse_sub(self, trg: str):
        """Parse input substitutions and identify the running mode.

        Sets ``self.mode``, ``self.trg``, ``self.resi`` (1-based positions)
        and ``self.sub`` (a one-column DataFrame, column ``'0'``, of
        substitutions written as ``<source><position><target>``).

        Requires ``self.seq`` to have been set by :meth:`parse_seq`.

        Raises:
            RuntimeError: if a position lies outside the sequence or the
                source residue of an ``X#Y`` token does not match the
                sequence.
        """
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper()
        self.resi = list()

        tokens = self.trg.split()

        # Identify running mode
        if len(tokens) == 1 and len(tokens[0]) == len(self.seq) and tokens[0].isalpha():
            # If single string of same length as sequence, seq vs seq mode.
            # Letters only: a token such as "M1A" must not be mistaken for a
            # target sequence just because the lengths happen to agree.
            self.mode = 'MUT'
            for resi, (src, dst) in enumerate(zip(self.seq, self.trg), 1):
                if src != dst:
                    self.sub.append(f"{src}{resi}{dst}")
                    self.resi.append(resi)
        else:
            self.trg = tokens
            # `tokens and ...`: all() over an empty list is True, which would
            # otherwise classify empty input as DMS with nothing to scan.
            # Patterns are anchored with \Z so that e.g. "12A" is not taken
            # for a number and "A12BC" is not taken for a substitution.
            if tokens and all(match(r'\d+\Z', x) for x in tokens):
                # If all strings are numbers, deep mutational scanning mode
                self.mode = 'DMS'
                for resi in map(int, tokens):
                    self._check_resi(resi)
                    src = self.seq[resi-1]
                    for dst in self._AA.replace(src, ''):
                        self.sub.append(f"{src}{resi}{dst}")
                    self.resi.append(resi)
            elif tokens and all(match(r'[A-Z]\d+[A-Z]\Z', x) for x in tokens):
                # If all strings are of the form X#Y, single substitution mode
                self.mode = 'MUT'
                self.sub = self.trg
                self.resi = [int(x[1:-1]) for x in tokens]
                for token, resi in zip(tokens, self.resi):
                    self._check_resi(resi)
                    # The source residue stated in the token must agree with
                    # the residue actually found at that position.
                    if self.seq[resi-1] != token[0]:
                        raise RuntimeError(f"Unrecognised input substitution {self.seq[resi-1]}{resi} /= {token[0]}{resi}")
            else:
                # Anything else (including empty input): scan every position
                self.mode = 'TMS'
                for resi, src in enumerate(self.seq, 1):
                    for dst in self._AA.replace(src, ''):
                        self.sub.append(f"{src}{resi}{dst}")
                    self.resi.append(resi)

        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file=None):
        """Initialise data.

        Args:
            src: input sequence.
            trg: substitution specification, see :meth:`parse_sub`.
            model_name: name passed to ``Model`` when a model must be loaded.
            scoring_strategy: stored as ``self.scoring_strategy``; not used
                inside this class.
            out_file: output path, or an object with a ``name`` attribute
                whose value is used as the path; ``None`` disables file
                output.
        """
        # If model has changed, load new model and cache it on the class so
        # the next instance can reuse it.
        cls = type(self)
        if cls.model.model_name != model_name:
            cls.model = Model(model_name)
        # Pin the model on the instance so that a later instance replacing
        # the cached model cannot change the one this instance runs with.
        self.model = cls.model
        # Always record the name: it is the score column label used below.
        self.model_name = model_name
        self.parse_seq(src)
        self.offset = 0
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.token_probs = None
        # Re-indexing the substitution table adds an empty score column
        # named after the model.
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_str = None
        self.out_buffer = getattr(out_file, 'name', out_file)

    def parse_output(self) -> None:
        """Format output data for visualisation.

        ``'TMS'`` output is drawn as a heatmap; ``'DMS'`` and ``'MUT'``
        output is sorted, optionally written as CSV to ``self.out_buffer``
        and rendered to an HTML table stored in ``self.out_str``.

        Raises:
            RuntimeError: if ``self.mode`` is not a recognised mode.
        """
        if self.mode == 'TMS':
            self.process_tms_mode()
        else:
            if self.mode == 'DMS':
                self.sort_by_residue_and_score()
            elif self.mode == 'MUT':
                self.sort_by_score()
            else:
                raise RuntimeError(f"Unrecognised mode {self.mode}")
            if self.out_buffer:
                self.out.round(2).to_csv(self.out_buffer, index=False, header=False)
            self.out_str = (self.out.style
                            .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                            .hide(axis=0)
                            .hide(axis=1)
                            .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
                            .to_html(justify='center'))

    def sort_by_score(self):
        """Sort the output table by score, highest first."""
        self.out = self.out.sort_values(self.model_name, ascending=False)

    def sort_by_residue_and_score(self):
        """Lay out the output as side-by-side blocks, one per position.

        Rows are sorted by position and then by descending score, and each
        position's rows become a pair of columns (substitution, score).
        The resulting columns are labelled ``0 .. 2*n_positions - 1``.
        """
        n_sub = self._N_SUB
        ranked = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                          .sort_values(['resi', self.model_name], ascending=[True, False])
                          .groupby(['resi'])
                          .head(n_sub)
                          .drop(['resi'], axis=1))
        n_blocks = ranked.shape[0] // n_sub
        # Each block contributes two columns, hence 2 * n_blocks labels.
        self.out = pd.concat([ranked.iloc[n_sub*x:n_sub*(x+1)].reset_index(drop=True) for x in range(n_blocks)]
                            , axis=1).set_axis(range(n_blocks*2), axis='columns')

    def process_tms_mode(self):
        """Build the residue-by-position score matrix and plot it."""
        self.out = self.assign_resi_and_group()
        self.out = self.concat_and_set_axis()
        # Scale so that the largest absolute score becomes 1.
        self.out /= self.out.abs().max().max()
        divs = self.calculate_divs()
        ncols = min(divs, key=lambda x: abs(x-60))
        nrows = ceil(self.out.shape[1]/ncols)
        ncols = self.adjust_ncols(ncols, nrows)
        self.plot_heatmap(ncols, nrows)

    def assign_resi_and_group(self):
        """Return the output table with a ``resi`` column, capped per position."""
        return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                        .groupby(['resi'])
                        .head(self._N_SUB))

    def concat_and_set_axis(self):
        """Return a matrix of scores: one row per residue code, one column per position.

        Consecutive groups of rows of ``self.out`` are taken as one position
        each; a zero-score row for the unchanged residue is added by
        :meth:`create_dataframe`, the rows are ordered by substitution label
        and the score column becomes one column of the result. Columns are
        labelled ``<residue><position>`` from ``self.seq``.
        """
        n_sub = self._N_SUB
        return (pd.concat([(self.out.iloc[n_sub*x:n_sub*(x+1)]
                            .pipe(self.create_dataframe)
                            .sort_values(['0'], ascending=[True])
                            .drop(['resi', '0'], axis=1)
                            .set_axis(list(self._AA))
                            .astype(float)
                            ) for x in range(self.out.shape[0]//n_sub)]
                        , axis=1)
                        .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))

    def create_dataframe(self, df):
        """Return ``df`` with a leading row for the unchanged residue.

        The label of the new row is the first label of ``df`` with its
        target replaced by its source (e.g. ``A1C`` -> ``A1A``); the two
        remaining columns are set to 0.
        """
        first = df.iloc[0, 0]
        unchanged = pd.Series([first[:-1]+first[0], 0, 0], index=df.columns).to_frame().T
        return pd.concat([unchanged, df], axis=0, ignore_index=True)

    def calculate_divs(self):
        """Return divisors of the column count that lie in 30..60, or ``[60]`` if none."""
        total = self.out.shape[1]
        return [x for x in range(30, min(60, total)+1) if total % x == 0] or [60]

    def adjust_ncols(self, ncols, nrows):
        """Return the columns-per-row to use for ``nrows`` heatmap rows.

        ``ncols`` is reduced while a smaller value, not below 46, still fits
        all columns into ``nrows`` rows, which evens out the row lengths.
        When no reduction is possible ``ncols`` is returned unchanged, so an
        exact divisor keeps all rows the same length.
        """
        total = self.out.shape[1]
        while ncols - 1 > 45 and (ncols - 1)*nrows >= total:
            ncols -= 1
        return ncols

    def plot_heatmap(self, ncols, nrows):
        """Draw the heatmap and, if an output path is set, save it as SVG.

        When ``self.out_buffer`` is set the figure is written there, closed,
        and the SVG text is read back into ``self.out_str``. Without an
        output path the figure is left open.
        """
        if nrows < 2:
            fig = self.plot_single_heatmap()
        else:
            fig = self.plot_multiple_heatmaps(ncols, nrows)

        if self.out_buffer:
            fig.savefig(self.out_buffer, format='svg')
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
            with open(self.out_buffer, 'r', encoding='utf-8') as f:
                self.out_str = f.read()

    def plot_single_heatmap(self):
        """Draw the whole matrix as one heatmap and return the figure."""
        fig = plt.figure(figsize=(12, 6))
        sns.heatmap(self.out
                  , cmap='RdBu'
                  , cbar=False
                  , square=True
                  , xticklabels=1
                  , yticklabels=1
                  , center=0
                    # Cells that are exactly 0 are marked with a dot.
                  , annot=self.out.map(lambda x: ' ' if x != 0 else '·')
                  , fmt='s'
                  , annot_kws={'size': 'xx-large'})
        fig.tight_layout()
        return fig

    def plot_multiple_heatmaps(self, ncols, nrows):
        """Draw the matrix as ``nrows`` stacked heatmaps of up to ``ncols`` columns each and return the figure."""
        fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
        for i in range(nrows):
            tmp = self.out.iloc[:,i*ncols:(i+1)*ncols]
            # Cells that are exactly 0 are marked with a dot.
            label = tmp.map(lambda x: ' ' if x != 0 else '·')
            sns.heatmap(tmp
                      , ax=ax[i]
                      , cmap='RdBu'
                      , cbar=False
                      , square=True
                      , xticklabels=1
                      , yticklabels=1
                      , center=0
                      , annot=label
                      , fmt='s'
                      , annot_kws={'size': 'xx-large'})
            ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
            ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
        fig.tight_layout()
        return fig

    def calculate(self):
        """Run model and parse output; returns ``self`` to allow chaining."""
        # NOTE(review): run_model is presumably expected to fill the score
        # column of self.out - confirm against Model.
        self.model.run_model(self)
        self.parse_output()
        return self

    def __str__(self):
        """Return output data in DataFrame format."""
        return str(self.out)

    def __repr__(self):
        """Return output data in html (or SVG) format.

        Falls back to the plain-text table when no formatted output has been
        produced yet, because ``__repr__`` must return a string.
        """
        return self.out_str if self.out_str is not None else str(self.out)