Massimo G. Totaro committed on
Commit
fba8f5e
β€’
1 Parent(s): 82caf01

update fix

Browse files
Files changed (8) hide show
  1. .gitignore +3 -1
  2. LICENSE +11 -0
  3. README.md +2 -2
  4. app.py +90 -19
  5. data.py +169 -40
  6. instructions.md +39 -13
  7. model.py +74 -47
  8. requirements.txt +1 -1
.gitignore CHANGED
@@ -1 +1,3 @@
1
- */
 
 
 
1
+ Dockerfile
2
+ *.ipynb
3
+ */
LICENSE ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2021, Massimo G. Totaro All rights reserved.
2
+
3
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4
+
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -4,10 +4,10 @@ emoji: πŸ“ˆ
4
  colorFrom: gray
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.33.1
8
  app_file: app.py
9
  pinned: false
10
- license: eupl-1.1
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: gray
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.8.0
8
  app_file: app.py
9
  pinned: false
10
+ license: bsd-2-clause
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,26 +1,97 @@
1
- from model import MODELS
2
- from data import Data
3
- import gradio as gr
4
  from tempfile import NamedTemporaryFile
 
 
 
 
 
 
 
5
 
6
- # scoring strategies
7
- SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]
8
 
9
  def app(*argv):
10
- seq, trg, model_name, scoring_strategy, out_file, *_ = argv
11
- html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
12
- return html, gr.File.update(value=out_file.name, visible=True)
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- with gr.Blocks() as demo, NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, open("instructions.md", "r") as md:
15
- gr.Markdown(md.read())
16
- seq = gr.Textbox(lines=2, label="Sequence", placeholder="Sequence here...", value='MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ')
17
- trg = gr.Textbox(lines=1, label="Substitutions", placeholder="Substitutions here...", value="61 214 19 30 122 140")
18
- model_name = gr.Dropdown(MODELS, label="Model", value=MODELS[1])
19
- scoring_strategy = gr.Dropdown(SCORING, label="Scoring strategy", value=SCORING[1])
20
- btn = gr.Button(value="Run")
21
- out = gr.HTML()
22
- bto = gr.File(value=out_file.name, visible=False, label="Download", file_count='single', interactive=False)
23
- btn.click(fn=app, inputs=[seq, trg, model_name, scoring_strategy, bto], outputs=[out, bto])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
25
  if __name__ == "__main__":
26
- demo.launch()
 
 
 
 
1
  from tempfile import NamedTemporaryFile
2
+ from gradio import Blocks, Button, Checkbox, Dropdown, Examples, File, HTML, Markdown, Textbox
3
+
4
+ from model import get_models
5
+ from data import Data
6
+
7
+ # Define scoring strategies
8
+ SCORING = ["wt-marginals", "masked-marginals"]
9
 
10
+ # Get available models
11
+ MODELS = get_models()
12
 
13
  def app(*argv):
14
+ """
15
+ Main application function
16
+ """
17
+ # Unpack the arguments
18
+ seq, trg, model_name, *_ = argv
19
+ scoring = SCORING[scoring_strategy.value]
20
+ try:
21
+ # Calculate the data based on the input parameters
22
+ data = Data(seq, trg, model_name, scoring, out_file).calculate()
23
+ except Exception as e:
24
+ # If an error occurs, return an HTML error message
25
+ return f'<!DOCTYPE html><html><body><h1 style="background-color:#F70D1A;text-align:center;">Error: {str(e)}</h1></body></html>', None
26
+ # If no error occurs, return the calculated data
27
+ return repr(data), File(value=out_file.name, visible=True)
28
 
29
+ # Create the Gradio interface
30
+ with open("instructions.md", "r", encoding="utf-8") as md,\
31
+ NamedTemporaryFile(mode='w+') as out_file,\
32
+ Blocks() as esm_scan:
33
+
34
+ # Define the interface components
35
+ Markdown(md.read())
36
+ seq = Textbox(
37
+ lines=2,
38
+ label="Sequence",
39
+ placeholder="FASTA sequence here...",
40
+ value=''
41
+ )
42
+ trg = Textbox(
43
+ lines=1,
44
+ label="Substitutions",
45
+ placeholder="Substitutions here...",
46
+ value=""
47
+ )
48
+ model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
49
+ scoring_strategy = Checkbox(value=True, label="Use masked-marginals scoring")
50
+ btn = Button(value="Run")
51
+ out = HTML()
52
+ bto = File(
53
+ value=out_file.name,
54
+ visible=False,
55
+ label="Download",
56
+ file_count='single',
57
+ interactive=False
58
+ )
59
+ btn.click(
60
+ fn=app,
61
+ inputs=[seq, trg, model_name],
62
+ outputs=[out, bto]
63
+ )
64
+ ex = Examples(
65
+ examples=[
66
+ [
67
+ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
68
+ "deep mutational scanning",
69
+ "facebook/esm2_t6_8M_UR50D"
70
+ ],
71
+ [
72
+ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
73
+ "217 218 219",
74
+ "facebook/esm2_t12_35M_UR50D"
75
+ ],
76
+ [
77
+ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
78
+ "R218K R218S R218N R218A R218V R218D",
79
+ "facebook/esm2_t30_150M_UR50D",
80
+ ],
81
+ [
82
+ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
83
+ "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
84
+ "facebook/esm2_t33_650M_UR50D",
85
+ ],
86
+ ],
87
+ inputs=[seq,
88
+ trg,
89
+ model_name],
90
+ outputs=[out,
91
+ bto],
92
+ fn=app
93
+ )
94
 
95
+ # Launch the Gradio interface
96
  if __name__ == "__main__":
97
+ esm_scan.launch()
data.py CHANGED
@@ -1,80 +1,209 @@
 
 
 
 
1
  from model import Model
 
 
 
2
  import pandas as pd
3
- from re import match
 
 
4
 
5
  class Data:
6
  """Container for input and output data"""
7
- # initialise empty model as static class member for efficiency
8
  model = Model()
9
 
10
- def parse_seq(self, src:str):
11
- "parse input sequence"
12
- self.seq = src.strip().upper()
13
- if not all(x in self.model.alphabet for x in src):
14
  raise RuntimeError("Unrecognised characters in sequence")
15
 
16
- def parse_sub(self, trg:str):
17
- "parse input substitutions"
18
  self.mode = None
19
  self.sub = list()
20
  self.trg = trg.strip().upper()
 
21
 
22
- # identify running mode
23
- if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq): # if single string of same length as sequence, seq vs seq mode
24
- self.mode = 'SVS'
25
- for resi,(src,trg) in enumerate(zip(self.seq, self.trg), 1):
 
26
  if src != trg:
27
  self.sub.append(f"{src}{resi}{trg}")
 
28
  else:
29
  self.trg = self.trg.split()
30
- if all(match(r'\d+', x) for x in self.trg): # if all strings are numbers, deep mutational scanning mode
 
31
  self.mode = 'DMS'
32
  for resi in map(int, self.trg):
33
  src = self.seq[resi-1]
34
- for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src,''):
35
  self.sub.append(f"{src}{resi}{trg}")
36
- elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg): # if all strings are of the form X#Y, single substitution mode
 
 
37
  self.mode = 'MUT'
38
  self.sub = self.trg
 
 
 
 
39
  else:
40
- raise RuntimeError("Unrecognised running mode; wrong inputs?")
41
-
 
 
 
 
42
  self.sub = pd.DataFrame(self.sub, columns=['0'])
43
 
44
- def __init__(self, src:str, trg:str, model_name:str, scoring_strategy:str, out_file):
45
  "initialise data"
46
  # if model has changed, load new model
47
  if self.model.model_name != model_name:
48
  self.model_name = model_name
49
  self.model = Model(model_name)
50
  self.parse_seq(src)
 
51
  self.parse_sub(trg)
52
  self.scoring_strategy = scoring_strategy
 
53
  self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
54
- self.out_buffer = out_file.name
 
55
 
56
- def parse_output(self) -> str:
57
  "format output data for visualisation"
58
- if self.mode == 'MUT': # if single substitution mode, sort by score
59
- self.out = self.out.sort_values(self.model_name, ascending=False)
60
- elif self.mode == 'DMS': # if deep mutational scanning mode, sort by residue and score
61
- self.out = pd.concat([(self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)) # FIX: this doesn't work if there's jolly characters in the input sequence
62
- .sort_values(['resi', self.model_name], ascending=[True,False])
63
- .groupby(['resi'])
64
- .head(19)
65
- .drop(['resi'], axis=1)).iloc[19*x:19*(x+1)]
66
- .reset_index(drop=True) for x in range(self.out.shape[0]//19)]
67
- , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
68
- # save to temporary file to be downloaded
69
- self.out.round(2).to_csv(self.out_buffer, index=False, header=False)
70
- return (self.out.style
71
- .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
72
- .hide(axis=0)
73
- .hide(axis=1)
74
- .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
75
- .to_html(justify='center'))
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def calculate(self):
78
  "run model and parse output"
79
  self.model.run_model(self)
80
- return self.parse_output()
 
 
 
 
 
 
 
 
 
 
1
+ from math import ceil
2
+ from re import match
3
+ import seaborn as sns
4
+
5
  from model import Model
6
+
7
+
8
+ import matplotlib.pyplot as plt
9
  import pandas as pd
10
+ import seaborn as sns
11
+
12
+ from model import Model
13
 
14
  class Data:
15
  """Container for input and output data"""
16
+ # Initialise empty model as static class member for efficiency
17
  model = Model()
18
 
19
+ def parse_seq(self, src: str):
20
+ """Parse input sequence"""
21
+ self.seq = src.strip().upper().replace('\n', '')
22
+ if not all(x in self.model.alphabet for x in self.seq):
23
  raise RuntimeError("Unrecognised characters in sequence")
24
 
25
+ def parse_sub(self, trg: str):
26
+ """Parse input substitutions"""
27
  self.mode = None
28
  self.sub = list()
29
  self.trg = trg.strip().upper()
30
+ self.resi = list()
31
 
32
+ # Identify running mode
33
+ if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq) and all(match(r'\w+', x) for x in self.trg):
34
+ # If single string of same length as sequence, seq vs seq mode
35
+ self.mode = 'MUT'
36
+ for resi, (src, trg) in enumerate(zip(self.seq, self.trg), 1):
37
  if src != trg:
38
  self.sub.append(f"{src}{resi}{trg}")
39
+ self.resi.append(resi)
40
  else:
41
  self.trg = self.trg.split()
42
+ if all(match(r'\d+', x) for x in self.trg):
43
+ # If all strings are numbers, deep mutational scanning mode
44
  self.mode = 'DMS'
45
  for resi in map(int, self.trg):
46
  src = self.seq[resi-1]
47
+ for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
48
  self.sub.append(f"{src}{resi}{trg}")
49
+ self.resi.append(resi)
50
+ elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg):
51
+ # If all strings are of the form X#Y, single substitution mode
52
  self.mode = 'MUT'
53
  self.sub = self.trg
54
+ self.resi = [int(x[1:-1]) for x in self.trg]
55
+ for s, *resi, _ in self.trg:
56
+ if self.seq[int(''.join(resi))-1] != s:
57
+ raise RuntimeError(f"Unrecognised input substitution {self.seq[int(''.join(resi))]}{int(''.join(resi))} /= {s}{int(''.join(resi))}")
58
  else:
59
+ self.mode = 'TMS'
60
+ for resi, src in enumerate(self.seq, 1):
61
+ for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
62
+ self.sub.append(f"{src}{resi}{trg}")
63
+ self.resi.append(resi)
64
+
65
  self.sub = pd.DataFrame(self.sub, columns=['0'])
66
 
67
+ def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file=None):
68
  "initialise data"
69
  # if model has changed, load new model
70
  if self.model.model_name != model_name:
71
  self.model_name = model_name
72
  self.model = Model(model_name)
73
  self.parse_seq(src)
74
+ self.offset = 0
75
  self.parse_sub(trg)
76
  self.scoring_strategy = scoring_strategy
77
+ self.token_probs = None
78
  self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
79
+ self.out_str = None
80
+ self.out_buffer = out_file.name if 'name' in dir(out_file) else out_file
81
 
82
+ def parse_output(self) -> None:
83
  "format output data for visualisation"
84
+ if self.mode == 'TMS':
85
+ self.process_tms_mode()
86
+ else:
87
+ if self.mode == 'DMS':
88
+ self.sort_by_residue_and_score()
89
+ elif self.mode == 'MUT':
90
+ self.sort_by_score()
91
+ else:
92
+ raise RuntimeError(f"Unrecognised mode {self.mode}")
93
+ if self.out_buffer:
94
+ self.out.round(2).to_csv(self.out_buffer, index=False, header=False)
95
+ self.out_str = (self.out.style
96
+ .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
97
+ .hide(axis=0)
98
+ .hide(axis=1)
99
+ .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
100
+ .to_html(justify='center'))
101
+
102
+ def sort_by_score(self):
103
+ self.out = self.out.sort_values(self.model_name, ascending=False)
104
+
105
+ def sort_by_residue_and_score(self):
106
+ self.out = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
107
+ .sort_values(['resi', self.model_name], ascending=[True,False])
108
+ .groupby(['resi'])
109
+ .head(19)
110
+ .drop(['resi'], axis=1))
111
+ self.out = pd.concat([self.out.iloc[19*x:19*(x+1)].reset_index(drop=True) for x in range(self.out.shape[0]//19)]
112
+ , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
113
+
114
+ def process_tms_mode(self):
115
+ self.out = self.assign_resi_and_group()
116
+ self.out = self.concat_and_set_axis()
117
+ self.out /= self.out.abs().max().max()
118
+ divs = self.calculate_divs()
119
+ ncols = min(divs, key=lambda x: abs(x-60))
120
+ nrows = ceil(self.out.shape[1]/ncols)
121
+ ncols = self.adjust_ncols(ncols, nrows)
122
+ self.plot_heatmap(ncols, nrows)
123
+
124
+ def assign_resi_and_group(self):
125
+ return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
126
+ .groupby(['resi'])
127
+ .head(19))
128
+
129
+ def concat_and_set_axis(self):
130
+ return (pd.concat([(self.out.iloc[19*x:19*(x+1)]
131
+ .pipe(self.create_dataframe)
132
+ .sort_values(['0'], ascending=[True])
133
+ .drop(['resi', '0'], axis=1)
134
+ .set_axis(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
135
+ 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])
136
+ .astype(float)
137
+ ) for x in range(self.out.shape[0]//19)]
138
+ , axis=1)
139
+ .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))
140
+
141
+ def create_dataframe(self, df):
142
+ return pd.concat([pd.Series([df.iloc[0, 0][:-1]+df.iloc[0, 0][0], 0, 0], index=df.columns).to_frame().T, df], axis=0, ignore_index=True)
143
+
144
+ def calculate_divs(self):
145
+ return [x for x in range(1, self.out.shape[1]+1) if self.out.shape[1] % x == 0 and 30 <= x and x <= 60] or [60]
146
+
147
+ def adjust_ncols(self, ncols, nrows):
148
+ while self.out.shape[1]/ncols < nrows and ncols > 45 and ncols*nrows >= self.out.shape[1]:
149
+ ncols -= 1
150
+ return ncols + 1
151
+
152
+ def plot_heatmap(self, ncols, nrows):
153
+ if nrows < 2:
154
+ self.plot_single_heatmap()
155
+ else:
156
+ self.plot_multiple_heatmaps(ncols, nrows)
157
+
158
+ if self.out_buffer:
159
+ plt.savefig(self.out_buffer, format='svg')
160
+ with open(self.out_buffer, 'r', encoding='utf-8') as f:
161
+ self.out_str = f.read()
162
+
163
+ def plot_single_heatmap(self):
164
+ fig = plt.figure(figsize=(12, 6))
165
+ sns.heatmap(self.out
166
+ , cmap='RdBu'
167
+ , cbar=False
168
+ , square=True
169
+ , xticklabels=1
170
+ , yticklabels=1
171
+ , center=0
172
+ , annot=self.out.map(lambda x: ' ' if x != 0 else 'Β·')
173
+ , fmt='s'
174
+ , annot_kws={'size': 'xx-large'})
175
+ fig.tight_layout()
176
+
177
+ def plot_multiple_heatmaps(self, ncols, nrows):
178
+ fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
179
+ for i in range(nrows):
180
+ tmp = self.out.iloc[:,i*ncols:(i+1)*ncols]
181
+ label = tmp.map(lambda x: ' ' if x != 0 else 'Β·')
182
+ sns.heatmap(tmp
183
+ , ax=ax[i]
184
+ , cmap='RdBu'
185
+ , cbar=False
186
+ , square=True
187
+ , xticklabels=1
188
+ , yticklabels=1
189
+ , center=0
190
+ , annot=label
191
+ , fmt='s'
192
+ , annot_kws={'size': 'xx-large'})
193
+ ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
194
+ ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
195
+ fig.tight_layout()
196
+
197
  def calculate(self):
198
  "run model and parse output"
199
  self.model.run_model(self)
200
+ self.parse_output()
201
+ return self
202
+
203
+ def __str__(self):
204
+ "return output data in DataFrame format"
205
+ return str(self.out)
206
+
207
+ def __repr__(self):
208
+ "return output data in html format"
209
+ return self.out_str
instructions.md CHANGED
@@ -1,13 +1,39 @@
1
- # **ESM zero-shot variant prediction**
2
- this was inspired from this [paper](https://doi.org/10.1101/2021.07.09.450648) and adaptated from this [repo](https://github.com/facebookresearch/esm/tree/main/esm)
3
-
4
- #### **Instructions**
5
- - in the 'sequence' text box the protein full amino acid sequence that is to be analysed must be given, jolly charachters (e.g. -X.B) are supported (but at the moment the visualisation does not show the correct results)
6
- - there's three running modes that can be chosen, depending on the input in the 'substitution' box:
7
- - if another sequence is given, the positions that are different between the two will be evaluated (NB the sequences must be of the same length) and their score returned
8
- - if a list of integers is given, a deep mutational scan will be performed at those positions in the input sequence and the scores for the amino acids, different from the original one, will be returned
9
- - if a single substitution or a list thereof is given (in the form of **B008S**), the single substitution score is returned
10
- - you can choose which ESM model to use for the calculations, these models are the ones that are available at runtime on Hugging Face Model Hub
11
- - there's 2 scoring strategies available: wt-marginals and masked marginals; the first one is faster, but less accurate, the second one considers the sequence context more thoroughly, but is sensibly slower (the run time scales linearly with sequence length)
12
- - the results will be shown in a table, with color coding and sorted by fitness (if performing a deep mutational scan)
13
- - the output data is available for download from the box at the bottom as a CSV file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # **ESM-Scan**
2
+ Calculate the <u>fitness of single amino acid substitutions</u> on proteins, using a [zero-shot](https://doi.org/10.1101/2021.07.09.450648) [language model predictor](https://github.com/facebookresearch/esm)
3
+
4
+ <details>
5
+ <summary> <b> USAGE INSTRUCTIONS </b> </summary>
6
+
7
+ ### **Setup**
8
+ No setup is required, just fill the input boxes with the required data and click on the `Run` button.
9
+ A list of examples can be found at the bottom of the page, click on them to autofill the fields.
10
+ If the server is not used for some time, it will go into standby.
11
+ Running a calculation resumes the tool from standby, the first run might take longer due to startup and model loading.
12
+
13
+ ### **Input**
14
+ - write the protein full amino acid sequence to be analysed in the **Sequence** text box
15
+ wildcard characters (e.g. `-X.B`) can be inserted but, at the moment, visualisation cannot handle them
16
+ - write the substitutions to test in the **Substitutions** box
17
+ there are three running modes that can be used, depending on the input:
18
+ + *single substitution* or list thereof (in the form of `R218K R218W`): the single substitution is scored
19
+ + *residue position* or list thereof: all possible substitutions will be evaluated
20
+ + *same-length sequence*: the differing amino acid substitutions will be evaluated, one by one
21
+ + any other *different input*: a deep mutational scan of the full sequence will be performed
22
+ - the ESM model to use for the calculations can be chosen among those that are available on Hugging Face Model Hub;
23
+ `esm2_t33_650M_UR50D` offers the best expense-accuracy tradeoff[*](https://doi.org/10.1126/science.ade2574)
24
+ - the `masked-marginals` scoring strategy considers sequence context at inference time, being slower but more accurate;
25
+ in case of long runtimes, you can tick the box off to speed the calculations up significantly, sacrificing accuracy
26
+ - when running a deep mutational scan, it is recommended to use smaller models (8M, 35M, or 150M parameters), since the runtime is significant — especially for longer sequences — and the server might be overloaded;
27
+ over 30 min might be necessary for calculating a 300-residue-long sequence with larger models
28
+ in general, accuracy is influenced significantly by the scoring strategy and less so by the model size, so it is suggested to reduce the latter first when optimising for runtime;
29
+ the scoring strategy's computational cost scales with the number of substitutions tested, while the model's scales with the wild-type sequence length
30
+ - it is possible to calculate the effect of multiple concurrent substitutions, but this has to be done manually, by changing the input sequence and running the calculation again
31
+
32
+ ### **Output**
33
+ Your results will be shown in a color-coded table, except for the deep mutational scan which will yield a heatmap.
34
+ The output data can be downloaded from the box at the bottom.
35
+ File extensions are not supported by the server and need to be appended to the filenames after downloading:
36
+ - `CSV` for tables
37
+ - `SVG` for full-sequence deep mutational scan
38
+
39
+ </details>
model.py CHANGED
@@ -1,72 +1,99 @@
1
  from huggingface_hub import HfApi, ModelFilter
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForMaskedLM
 
 
4
 
5
- # fetch suitable ESM models from HuggingFace Hub
6
- MODELS = [m.modelId for m in HfApi().list_models(filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"), sort="lastModified", direction=-1)]
7
- if not any(MODELS):
8
- raise RuntimeError("Error while retrieving models from HuggingFace Hub")
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
10
  class Model:
11
- """Wrapper for ESM models"""
12
- def __init__(self, model_name:str=""):
13
- "load selected model and tokenizer"
14
  self.model_name = model_name
15
  if model_name:
16
  self.model = AutoModelForMaskedLM.from_pretrained(model_name)
17
  self.batch_converter = AutoTokenizer.from_pretrained(model_name)
18
  self.alphabet = self.batch_converter.get_vocab()
 
19
  if torch.cuda.is_available():
20
  self.model = self.model.cuda()
21
 
22
- def __rshift__(self, batch_tokens:torch.Tensor) -> torch.Tensor:
23
- "run model on batch of tokens"
24
- return self.model(batch_tokens)["logits"]
25
-
26
- def __lshift__(self, input:str) -> torch.Tensor:
27
- "convert input string to batch of tokens"
28
- return self.batch_converter(input, return_tensors="pt")["input_ids"]
29
 
30
- def __getitem__(self, key:str) -> int:
31
- "get token ID from character"
32
  return self.alphabet[key]
33
-
34
  def run_model(self, data):
35
- "run model on data"
36
  def label_row(row, token_probs):
37
- "label row with score"
 
38
  wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
 
39
  score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
40
  return score.item()
41
-
42
- batch_tokens = self<<data.seq
43
 
44
- # run model with selected scoring strategy (info thereof available in the original ESM paper)
45
- if data.scoring_strategy.startswith("wt-marginals"):
46
- with torch.no_grad():
47
- token_probs = torch.log_softmax(self>>batch_tokens, dim=-1)
48
- data.out[self.model_name] = data.sub.apply(
49
- lambda row: label_row(
50
- row['0'],
51
- token_probs,
52
- ),
53
- axis=1,
54
- )
55
- elif data.scoring_strategy.startswith("masked-marginals"):
56
  all_token_probs = []
 
57
  for i in range(batch_tokens.size()[1]):
58
- batch_tokens_masked = batch_tokens.clone()
59
- batch_tokens_masked[0, i] = self['<mask>']
60
- with torch.no_grad():
61
- token_probs = torch.log_softmax(
62
- self>>batch_tokens_masked, dim=-1
63
- )
64
- all_token_probs.append(token_probs[:, i])
 
 
 
 
 
 
 
 
 
65
  token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
66
- data.out[self.model_name] = data.sub.apply(
67
- lambda row: label_row(
68
- row['0'],
69
- token_probs,
70
- ),
71
- axis=1,
72
- )
 
 
 
1
  from huggingface_hub import HfApi, ModelFilter
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+ from transformers.tokenization_utils_base import BatchEncoding
5
+ from transformers.modeling_outputs import MaskedLMOutput
6
 
7
+ # Function to fetch suitable ESM models from HuggingFace Hub
8
+ def get_models() -> list[None|str]:
9
+ """Fetch suitable ESM models from HuggingFace Hub."""
10
+ if not any(
11
+ out := [
12
+ m.modelId for m in HfApi().list_models(
13
+ filter=ModelFilter(
14
+ author="facebook", model_name="esm", task="fill-mask"
15
+ ),
16
+ sort="lastModified",
17
+ direction=-1
18
+ )
19
+ ]
20
+ ):
21
+ raise RuntimeError("Error while retrieving models from HuggingFace Hub")
22
+ return out
23
 
24
+ # Class to wrap ESM models
25
  class Model:
26
+ """Wrapper for ESM models."""
27
+ def __init__(self, model_name: str = ""):
28
+ """Load selected model and tokenizer."""
29
  self.model_name = model_name
30
  if model_name:
31
  self.model = AutoModelForMaskedLM.from_pretrained(model_name)
32
  self.batch_converter = AutoTokenizer.from_pretrained(model_name)
33
  self.alphabet = self.batch_converter.get_vocab()
34
+ # Check if CUDA is available and if so, use it
35
  if torch.cuda.is_available():
36
  self.model = self.model.cuda()
37
 
38
+ def tokenise(self, input: str) -> BatchEncoding:
39
+ """Convert input string to batch of tokens."""
40
+ return self.batch_converter(input, return_tensors="pt")
41
+
42
+ def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
43
+ """Run model on batch of tokens."""
44
+ return self.model(batch_tokens, **kwargs)
45
 
46
+ def __getitem__(self, key: str) -> int:
47
+ """Get token ID from character."""
48
  return self.alphabet[key]
49
+
50
  def run_model(self, data):
51
+ """Run model on data."""
52
  def label_row(row, token_probs):
53
+ """Label row with score."""
54
+ # Extract wild type, index and mutant type from the row
55
  wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
56
+ # Calculate the score as the difference between the token probabilities of the mutant type and the wild type
57
  score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
58
  return score.item()
 
 
59
 
60
+ # Tokenise the sequence data
61
+ batch_tokens = self.tokenise(data.seq).input_ids
62
+
63
+ # Calculate the token probabilities without updating the model parameters
64
+ with torch.no_grad():
65
+ token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
66
+ # Store the token probabilities in the data
67
+ data.token_probs = token_probs.cpu().numpy()
68
+
69
+ # If the scoring strategy starts with "masked-marginals"
70
+ if data.scoring_strategy.startswith("masked-marginals"):
 
71
  all_token_probs = []
72
+ # For each token in the batch
73
  for i in range(batch_tokens.size()[1]):
74
+ # If the token is in the list of residues
75
+ if i in data.resi:
76
+ # Clone the batch tokens and mask the current token
77
+ batch_tokens_masked = batch_tokens.clone()
78
+ batch_tokens_masked[0, i] = self['<mask>']
79
+ # Calculate the masked token probabilities
80
+ with torch.no_grad():
81
+ masked_token_probs = torch.log_softmax(
82
+ self(batch_tokens_masked).logits, dim=-1
83
+ )
84
+ else:
85
+ # If the token is not in the list of residues, use the original token probabilities
86
+ masked_token_probs = token_probs
87
+ # Append the token probabilities to the list
88
+ all_token_probs.append(masked_token_probs[:, i])
89
+ # Concatenate all token probabilities
90
  token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
91
+
92
+ # Apply the label_row function to each row of the substitutions dataframe
93
+ data.out[self.model_name] = data.sub.apply(
94
+ lambda row: label_row(
95
+ row['0'],
96
+ token_probs,
97
+ ),
98
+ axis=1,
99
+ )
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  gradio
2
- huggingface_hub
3
  pandas
 
4
  torch
5
  transformers
 
1
  gradio
 
2
  pandas
3
+ seaborn
4
  torch
5
  transformers