MassimoGregorioTotaro commited on
Commit
b212cb1
1 Parent(s): 462a012

fix ok, reformatting

Browse files
Files changed (3) hide show
  1. app.py +4 -157
  2. data.py +80 -0
  3. model.py +72 -0
app.py CHANGED
@@ -1,167 +1,14 @@
 
 
1
  import gradio as gr
2
- from huggingface_hub import HfApi, ModelFilter
3
- import pandas as pd
4
- from re import match
5
  from tempfile import NamedTemporaryFile
6
- import torch
7
- from transformers import AutoTokenizer, AutoModelForMaskedLM
8
-
9
- # fetch suitable ESM models from HuggingFace Hub
10
- MODELS = [m.modelId for m in HfApi().list_models(filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"), sort="lastModified", direction=-1)]
11
- if not any(MODELS):
12
- raise RuntimeError("Error while retrieving models from HuggingFace Hub")
13
 
14
  # scoring strategies
15
  SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]
16
 
17
- class Model:
18
- """Wrapper for ESM models"""
19
- def __init__(self, model_name:str=""):
20
- "load selected model and tokenizer"
21
- self.model_name = model_name
22
- if model_name:
23
- self.model = AutoModelForMaskedLM.from_pretrained(model_name)
24
- self.batch_converter = AutoTokenizer.from_pretrained(model_name)
25
- self.alphabet = self.batch_converter.get_vocab()
26
- if torch.cuda.is_available():
27
- self.model = self.model.cuda()
28
-
29
- def __rshift__(self, batch_tokens:torch.Tensor) -> torch.Tensor:
30
- "run model on batch of tokens"
31
- return self.model(batch_tokens)["logits"]
32
-
33
- def __lshift__(self, input:str) -> torch.Tensor:
34
- "convert input string to batch of tokens"
35
- return self.batch_converter(input, return_tensors="pt")["input_ids"]
36
-
37
- def __getitem__(self, key:str) -> int:
38
- "get token ID from character"
39
- return self.alphabet[key]
40
-
41
- def run_model(self, data):
42
- "run model on data"
43
- def label_row(row, token_probs):
44
- "label row with score"
45
- wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
46
- score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
47
- return score.item()
48
-
49
- batch_tokens = self<<data.seq
50
-
51
- # run model with selected scoring strategy (info thereof available in the original ESM paper)
52
- if data.scoring_strategy.startswith("wt-marginals"):
53
- with torch.no_grad():
54
- token_probs = torch.log_softmax(self>>batch_tokens, dim=-1)
55
- data.out[self.model_name] = data.sub.apply(
56
- lambda row: label_row(
57
- row['0'],
58
- token_probs,
59
- ),
60
- axis=1,
61
- )
62
- elif data.scoring_strategy.startswith("masked-marginals"):
63
- all_token_probs = []
64
- for i in range(batch_tokens.size()[1]):
65
- batch_tokens_masked = batch_tokens.clone()
66
- batch_tokens_masked[0, i] = self['<mask>']
67
- with torch.no_grad():
68
- token_probs = torch.log_softmax(
69
- self>>batch_tokens_masked, dim=-1
70
- )
71
- all_token_probs.append(token_probs[:, i])
72
- token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
73
- data.out[self.model_name] = data.sub.apply(
74
- lambda row: label_row(
75
- row['0'],
76
- token_probs,
77
- ),
78
- axis=1,
79
- )
80
-
81
- class Data:
82
- """Container for input and output data"""
83
- # initialise empty model as static class member for efficiency
84
- model = Model()
85
-
86
- def parse_seq(self, src:str):
87
- "parse input sequence"
88
- self.seq = src.strip().upper()
89
- if not all(x in self.model.alphabet for x in src):
90
- raise RuntimeError("Unrecognised characters in sequence")
91
-
92
- def parse_sub(self, trg:str):
93
- "parse input substitutions"
94
- self.mode = None
95
- self.sub = list()
96
- self.trg = trg.strip().upper()
97
-
98
- # identify running mode
99
- if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq): # if single string of same length as sequence, seq vs seq mode
100
- self.mode = 'SVS'
101
- for resi,(src,trg) in enumerate(zip(self.seq, self.trg), 1):
102
- if src != trg:
103
- self.sub.append(f"{src}{resi}{trg}")
104
- else:
105
- self.trg = self.trg.split()
106
- if all(match(r'\d+', x) for x in self.trg): # if all strings are numbers, deep mutational scanning mode
107
- self.mode = 'DMS'
108
- for resi in map(int, self.trg):
109
- src = self.seq[resi-1]
110
- for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src,''):
111
- self.sub.append(f"{src}{resi}{trg}")
112
- elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg): # if all strings are of the form X#Y, single substitution mode
113
- self.mode = 'MUT'
114
- self.sub = self.trg
115
- else:
116
- raise RuntimeError("Unrecognised running mode; wrong inputs?")
117
-
118
- self.sub = pd.DataFrame(self.sub, columns=['0'])
119
-
120
- def __init__(self, src:str, trg:str, model_name:str, scoring_strategy:str, out_file):
121
- "initialise data"
122
- # if model has changed, load new model
123
- if self.model.model_name != model_name:
124
- self.model_name = model_name
125
- self.model = Model(model_name)
126
- self.parse_seq(src)
127
- self.parse_sub(trg)
128
- self.scoring_strategy = scoring_strategy
129
- self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
130
- self.out_buffer = out_file.name
131
-
132
- def parse_output(self) -> str:
133
- "format output data for visualisation"
134
- if self.mode == 'MUT': # if single substitution mode, sort by score
135
- self.out = self.out.sort_values(self.model_name, ascending=False)
136
- elif self.mode == 'DMS': # if deep mutational scanning mode, sort by residue and score
137
- self.out = pd.concat([(self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)) # FIX: this doesn't work if there's jolly characters in the input sequence
138
- .sort_values(['resi', self.model_name], ascending=[True,False])
139
- .groupby(['resi'])
140
- .head(19)
141
- .drop(['resi'], axis=1)).iloc[19*x:19*(x+1)]
142
- .reset_index(drop=True) for x in range(self.out.shape[0]//19)]
143
- , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
144
- # save to temporary file to be downloaded
145
- self.out.round(2).to_csv(self.out_buffer, index=False)
146
- return (self.out.style
147
- .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
148
- .hide(axis=0)
149
- .hide(axis=1)
150
- .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
151
- .to_html(justify='center'))
152
-
153
- def calculate(self):
154
- "run model and parse output"
155
- self.model.run_model(self)
156
- return self.parse_output()
157
-
158
  def app(*argv):
159
- "run app"
160
- # seq, trg, model_name, scoring_strategy, out_file, *_ = argv
161
- # html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
162
- df = pd.DataFrame((pd.np.random.random((10, 5))-0.5)*10, columns=list('ABCDE'))
163
- df.to_csv(out_file.name, index=False)
164
- html = df.to_html(justify='center')
165
  return html, gr.File.update(value=out_file.name, visible=True)
166
 
167
  with gr.Blocks() as demo, NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, open("instructions.md", "r") as md:
 
1
+ from model import MODELS
2
+ from data import Data
3
  import gradio as gr
 
 
 
4
  from tempfile import NamedTemporaryFile
 
 
 
 
 
 
 
5
 
6
  # scoring strategies
7
  SCORING = ["masked-marginals (more accurate)", "wt-marginals (faster)"]
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def app(*argv):
10
+ seq, trg, model_name, scoring_strategy, out_file, *_ = argv
11
+ html = Data(seq, trg, model_name, scoring_strategy, out_file).calculate()
 
 
 
 
12
  return html, gr.File.update(value=out_file.name, visible=True)
13
 
14
  with gr.Blocks() as demo, NamedTemporaryFile(mode='w+', prefix='out_', suffix='.csv') as out_file, open("instructions.md", "r") as md:
data.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import Model
2
+ import pandas as pd
3
+ from re import match
4
+
5
+ class Data:
6
+ """Container for input and output data"""
7
+ # initialise empty model as static class member for efficiency
8
+ model = Model()
9
+
10
+ def parse_seq(self, src:str):
11
+ "parse input sequence"
12
+ self.seq = src.strip().upper()
13
+ if not all(x in self.model.alphabet for x in src):
14
+ raise RuntimeError("Unrecognised characters in sequence")
15
+
16
+ def parse_sub(self, trg:str):
17
+ "parse input substitutions"
18
+ self.mode = None
19
+ self.sub = list()
20
+ self.trg = trg.strip().upper()
21
+
22
+ # identify running mode
23
+ if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq): # if single string of same length as sequence, seq vs seq mode
24
+ self.mode = 'SVS'
25
+ for resi,(src,trg) in enumerate(zip(self.seq, self.trg), 1):
26
+ if src != trg:
27
+ self.sub.append(f"{src}{resi}{trg}")
28
+ else:
29
+ self.trg = self.trg.split()
30
+ if all(match(r'\d+', x) for x in self.trg): # if all strings are numbers, deep mutational scanning mode
31
+ self.mode = 'DMS'
32
+ for resi in map(int, self.trg):
33
+ src = self.seq[resi-1]
34
+ for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src,''):
35
+ self.sub.append(f"{src}{resi}{trg}")
36
+ elif all(match(r'[A-Z]\d+[A-Z]', x) for x in self.trg): # if all strings are of the form X#Y, single substitution mode
37
+ self.mode = 'MUT'
38
+ self.sub = self.trg
39
+ else:
40
+ raise RuntimeError("Unrecognised running mode; wrong inputs?")
41
+
42
+ self.sub = pd.DataFrame(self.sub, columns=['0'])
43
+
44
+ def __init__(self, src:str, trg:str, model_name:str, scoring_strategy:str, out_file):
45
+ "initialise data"
46
+ # if model has changed, load new model
47
+ if self.model.model_name != model_name:
48
+ self.model_name = model_name
49
+ self.model = Model(model_name)
50
+ self.parse_seq(src)
51
+ self.parse_sub(trg)
52
+ self.scoring_strategy = scoring_strategy
53
+ self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
54
+ self.out_buffer = out_file.name
55
+
56
+ def parse_output(self) -> str:
57
+ "format output data for visualisation"
58
+ if self.mode == 'MUT': # if single substitution mode, sort by score
59
+ self.out = self.out.sort_values(self.model_name, ascending=False)
60
+ elif self.mode == 'DMS': # if deep mutational scanning mode, sort by residue and score
61
+ self.out = pd.concat([(self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)) # FIX: this doesn't work if there's jolly characters in the input sequence
62
+ .sort_values(['resi', self.model_name], ascending=[True,False])
63
+ .groupby(['resi'])
64
+ .head(19)
65
+ .drop(['resi'], axis=1)).iloc[19*x:19*(x+1)]
66
+ .reset_index(drop=True) for x in range(self.out.shape[0]//19)]
67
+ , axis=1).set_axis(range(self.out.shape[0]//19*2), axis='columns')
68
+ # save to temporary file to be downloaded
69
+ self.out.round(2).to_csv(self.out_buffer, index=False)
70
+ return (self.out.style
71
+ .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
72
+ .hide(axis=0)
73
+ .hide(axis=1)
74
+ .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
75
+ .to_html(justify='center'))
76
+
77
+ def calculate(self):
78
+ "run model and parse output"
79
+ self.model.run_model(self)
80
+ return self.parse_output()
model.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi, ModelFilter
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+
5
+ # fetch suitable ESM models from HuggingFace Hub
6
+ MODELS = [m.modelId for m in HfApi().list_models(filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"), sort="lastModified", direction=-1)]
7
+ if not any(MODELS):
8
+ raise RuntimeError("Error while retrieving models from HuggingFace Hub")
9
+
10
+ class Model:
11
+ """Wrapper for ESM models"""
12
+ def __init__(self, model_name:str=""):
13
+ "load selected model and tokenizer"
14
+ self.model_name = model_name
15
+ if model_name:
16
+ self.model = AutoModelForMaskedLM.from_pretrained(model_name)
17
+ self.batch_converter = AutoTokenizer.from_pretrained(model_name)
18
+ self.alphabet = self.batch_converter.get_vocab()
19
+ if torch.cuda.is_available():
20
+ self.model = self.model.cuda()
21
+
22
+ def __rshift__(self, batch_tokens:torch.Tensor) -> torch.Tensor:
23
+ "run model on batch of tokens"
24
+ return self.model(batch_tokens)["logits"]
25
+
26
+ def __lshift__(self, input:str) -> torch.Tensor:
27
+ "convert input string to batch of tokens"
28
+ return self.batch_converter(input, return_tensors="pt")["input_ids"]
29
+
30
+ def __getitem__(self, key:str) -> int:
31
+ "get token ID from character"
32
+ return self.alphabet[key]
33
+
34
+ def run_model(self, data):
35
+ "run model on data"
36
+ def label_row(row, token_probs):
37
+ "label row with score"
38
+ wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
39
+ score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
40
+ return score.item()
41
+
42
+ batch_tokens = self<<data.seq
43
+
44
+ # run model with selected scoring strategy (info thereof available in the original ESM paper)
45
+ if data.scoring_strategy.startswith("wt-marginals"):
46
+ with torch.no_grad():
47
+ token_probs = torch.log_softmax(self>>batch_tokens, dim=-1)
48
+ data.out[self.model_name] = data.sub.apply(
49
+ lambda row: label_row(
50
+ row['0'],
51
+ token_probs,
52
+ ),
53
+ axis=1,
54
+ )
55
+ elif data.scoring_strategy.startswith("masked-marginals"):
56
+ all_token_probs = []
57
+ for i in range(batch_tokens.size()[1]):
58
+ batch_tokens_masked = batch_tokens.clone()
59
+ batch_tokens_masked[0, i] = self['<mask>']
60
+ with torch.no_grad():
61
+ token_probs = torch.log_softmax(
62
+ self>>batch_tokens_masked, dim=-1
63
+ )
64
+ all_token_probs.append(token_probs[:, i])
65
+ token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
66
+ data.out[self.model_name] = data.sub.apply(
67
+ lambda row: label_row(
68
+ row['0'],
69
+ token_probs,
70
+ ),
71
+ axis=1,
72
+ )