NeptuniaNep commited on
Commit
a552ae2
β€’
1 Parent(s): a9799e9
Files changed (5) hide show
  1. README.md +11 -5
  2. app.py +272 -73
  3. gitattributes.txt +34 -0
  4. msa.py +62 -0
  5. requirements.txt +12 -4
README.md CHANGED
@@ -1,12 +1,18 @@
1
  ---
 
2
  title: SVM
3
- emoji: πŸ”₯
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.21.0
8
  app_file: app.py
9
  pinned: false
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ # https://huggingface.co/docs/hub/spaces-config-reference
3
  title: SVM
4
+ emoji: 🧬
5
+ colorFrom: green
6
+ colorTo: green
7
+ sdk: gradio
 
8
  app_file: app.py
9
  pinned: false
10
+ models:
11
+ - InstaDeepAI/nucleotide-transformer-500m-1000g
12
+ - facebook/esmfold_v1
13
+ - sentence-transformers/all-mpnet-base-v2
14
+ python_version: 3.10.4
15
+ license: mit
16
  ---
17
 
18
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,76 +1,275 @@
 
 
 
 
 
 
 
 
 
1
  import torch
2
- import streamlit as st
3
- from transformers import AutoTokenizer, OPTForCausalLM
4
-
5
-
6
- @st.cache_resource
7
- def load_model():
8
- tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-30b")
9
- model = OPTForCausalLM.from_pretrained("facebook/galactica-30b", device_map='auto', low_cpu_mem_usage=True, torch_dtype=torch.float16)
10
- model.gradient_checkpointing_enable()
11
- return tokenizer, model
12
-
13
-
14
- st.set_page_config(
15
- page_title='BioML-SVM',
16
- layout="wide"
17
- )
18
-
19
- with st.spinner("Loading Models and Tokens..."):
20
- tokenizer, model = load_model()
21
-
22
- with st.form(key='my_form'):
23
- col1, col2 = st.columns([10, 1])
24
- text_input = col1.text_input(label='Enter the amino sequence')
25
- with col2:
26
- st.text('')
27
- st.text('')
28
- submit_button = st.form_submit_button(label='Submit')
29
-
30
- if submit_button:
31
- st.session_state['result_done'] = False
32
- # input_text = "[START_AMINO]GHMQSITAGQKVISKHKNGRFYQCEVVRLTTETFYEVNFDDGSFSDNLYPEDIVSQDCLQFGPPAEGEVVQVRWTDGQVYGAKFVASHPIQMYQVEFEDGSQLVVKRDDVYTLDEELP[END_AMINO]"
33
- with st.spinner('Generating...'):
34
- # formatted_text = f"[START_AMINO]{text_input}[END_AMINO]"
35
- # formatted_text = f"Here is the sequence: [START_AMINO]{text_input}[END_AMINO]"
36
- formatted_text = f"{text_input}"
37
- input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to("cuda")
38
- outputs = model.generate(
39
- input_ids=input_ids,
40
- max_new_tokens=500
41
- )
42
- result = tokenizer.decode(outputs[0]).replace(formatted_text, "")
43
- st.markdown(result)
44
-
45
- if 'result_done' not in st.session_state or not st.session_state.result_done:
46
- st.session_state['result_done'] = True
47
- st.session_state['previous_state'] = result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  else:
49
- if 'result_done' in st.session_state and st.session_state.result_done:
50
- st.markdown(st.session_state.previous_state)
51
-
52
- if 'result_done' in st.session_state and st.session_state.result_done:
53
- with st.form(key='ask_more'):
54
- col1, col2 = st.columns([10, 1])
55
- text_input = col1.text_input(label='Ask more question')
56
- with col2:
57
- st.text('')
58
- st.text('')
59
- submit_button = st.form_submit_button(label='Submit')
60
-
61
- if submit_button:
62
- with st.spinner('Generating...'):
63
- # formatted_text = f"[START_AMINO]{text_input}[END_AMINO]"
64
- formatted_text = f"Q:{text_input}\n\nA:\n\n"
65
- input_ids = tokenizer(formatted_text, return_tensors="pt").input_ids.to("cuda")
66
-
67
- outputs = model.generate(
68
- input_ids=input_ids,
69
- max_length=len(formatted_text) + 500,
70
- do_sample=True,
71
- top_k=40,
72
- num_beams=1,
73
- num_return_sequences=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  )
75
- result = tokenizer.decode(outputs[0]).replace(formatted_text, "")
76
- st.markdown(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
2
+ from typing import Tuple
3
+ import os
4
+ import sys
5
+ from urllib import request
6
+
7
+ import gradio as gr
8
+ import requests
9
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
10
  import torch
11
+ import progres as pg
12
+ import esm
13
+
14
+ import msa
15
+
16
+
17
+ tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
18
+ model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
19
+ model_nt.eval()
20
+
21
+ tokenizer_aa = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
22
+ model_aa = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
23
+ model_aa.eval()
24
+
25
+ tokenizer_se = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
26
+ model_se = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
27
+ model_se.eval()
28
+
29
+ msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
30
+ msa_transformer = msa_transformer.eval()
31
+ msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()
32
+
33
+
34
+
35
+ def nt_embed(sequence: str):
36
+ tokens_ids = tokenizer_nt.batch_encode_plus([sequence], return_tensors="pt")["input_ids"]
37
+ attention_mask = tokens_ids != tokenizer_nt.pad_token_id
38
+ with torch.no_grad():
39
+ torch_outs = model_nt(
40
+ tokens_ids,#.to('cuda'),
41
+ attention_mask=attention_mask,#.to('cuda'),
42
+ output_hidden_states=True
43
+ )
44
+ last_layer_CLS = torch_outs.hidden_states[-1].detach()[:, 0, :][0]
45
+ return last_layer_CLS
46
+
47
+
48
+ def aa_embed(sequence: str):
49
+ tokens = tokenizer_aa([sequence], return_tensors="pt")
50
+ with torch.no_grad():
51
+ torch_outs = model_aa(**tokens)
52
+ return torch_outs[0]
53
+
54
+
55
+ def se_embed(sentence: str):
56
+ encoded_input = tokenizer_se([sentence], return_tensors='pt')
57
+ with torch.no_grad():
58
+ model_output = model_se(**encoded_input)
59
+ return model_output[0]
60
+
61
+
62
+ def msa_embed(sequences: list):
63
+ inputs = msa.greedy_select(sequences, num_seqs=128) # can change this to pass more/fewer sequences
64
+ msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
65
+ msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
66
+
67
+ with torch.no_grad():
68
+ temp = msa_transformer(msa_transformer_batch_tokens,repr_layers=[12])['representations']
69
+ temp = temp[12][:,:,0,:]
70
+ temp = torch.mean(temp,(0,1))
71
+ return temp
72
+
73
+
74
+ def go_embed(terms):
75
+ pass
76
+
77
+
78
+ def download_data_if_required():
79
+ url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
80
+ fps = [pg.trained_model_fp]
81
+ urls = [f"{url_base}/trained_model.pt"]
82
+ #for targetdb in pre_embedded_dbs:
83
+ # fps.append(os.path.join(database_dir, targetdb + ".pt"))
84
+ # urls.append(f"{url_base}/{targetdb}.pt")
85
+
86
+ if not os.path.isdir(pg.trained_model_dir):
87
+ os.makedirs(pg.trained_model_dir)
88
+ #if not os.path.isdir(database_dir):
89
+ # os.makedirs(database_dir)
90
+
91
+ printed = False
92
+ for fp, url in zip(fps, urls):
93
+ if not os.path.isfile(fp):
94
+ if not printed:
95
+ print("Downloading data as first time setup (~340 MB) to ", pg.progres_dir,
96
+ ", internet connection required, this can take a few minutes",
97
+ sep="", file=sys.stderr)
98
+ printed = True
99
+ try:
100
+ request.urlretrieve(url, fp)
101
+ d = torch.load(fp, map_location="cpu")
102
+ if fp == pg.trained_model_fp:
103
+ assert "model" in d
104
+ else:
105
+ assert "embeddings" in d
106
+ except:
107
+ if os.path.isfile(fp):
108
+ os.remove(fp)
109
+ print("Failed to download from", url, "and save to", fp, file=sys.stderr)
110
+ print("Exiting", file=sys.stderr)
111
+ sys.exit(1)
112
+
113
+ if printed:
114
+ print("Data downloaded successfully", file=sys.stderr)
115
+
116
+
117
+ def get_pdb(pdb_code="", filepath=""):
118
+ if pdb_code is None or pdb_code == "":
119
+ try:
120
+ with open(filepath.name) as f:
121
+ return f.read()
122
+ except AttributeError as e:
123
+ return None
124
  else:
125
+ return requests.get(f"https://files.rcsb.org/view/{pdb_code}.pdb").content.decode()
126
+
127
+
128
+ def molecule(pdb):
129
+
130
+ x = (
131
+ """<!DOCTYPE html>
132
+ <html>
133
+ <head>
134
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
135
+ <style>
136
+ body{
137
+ font-family:sans-serif
138
+ }
139
+ .mol-container {
140
+ width: 100%;
141
+ height: 600px;
142
+ position: relative;
143
+ }
144
+ .mol-container select{
145
+ background-image:None;
146
+ }
147
+ </style>
148
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
149
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
150
+ </head>
151
+ <body>
152
+ <div id="container" class="mol-container"></div>
153
+
154
+ <script>
155
+ let pdb = `"""
156
+ + pdb
157
+ + """`
158
+
159
+ $(document).ready(function () {
160
+ let element = $("#container");
161
+ let config = { backgroundColor: "black" };
162
+ let viewer = $3Dmol.createViewer(element, config);
163
+ viewer.addModel(pdb, "pdb");
164
+ viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
165
+ viewer.addSurface("MS", { opacity: .5, color: "white" });
166
+ viewer.zoomTo();
167
+ viewer.render();
168
+ viewer.zoom(0.8, 2000);
169
+ })
170
+ </script>
171
+ </body></html>"""
172
+ )
173
+
174
+ return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
175
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
176
+ allow-scripts allow-same-origin allow-popups
177
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
178
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
179
+
180
+
181
+ def str2coords(s):
182
+ coords = []
183
+ for line in s.split('\n'):
184
+ if (line.startswith("ATOM ") or line.startswith("HETATM")) and line[12:16].strip() == "CA":
185
+ coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
186
+ elif line.startswith("ENDMDL"):
187
+ break
188
+ return coords
189
+
190
+
191
+ def update_st(inp, file):
192
+ pdb = get_pdb(inp, file)
193
+ return (molecule(pdb), pg.embed_coords(str2coords(pdb)))
194
+
195
+
196
+ def update_nt(inp):
197
+ return str(nt_embed(inp or ''))
198
+
199
+
200
+ def update_aa(inp):
201
+ return str(aa_embed(inp))
202
+
203
+
204
+ def update_se(inp):
205
+ return str(se_embed(inp))
206
+
207
+
208
+ def update_go(inp):
209
+ return str(go_embed(inp))
210
+
211
+
212
+ def update_msa(inp):
213
+ return str(msa_embed(msa.read_msa(inp.name)))
214
+
215
+
216
+ demo = gr.Blocks()
217
+
218
+ with demo:
219
+ with gr.Tabs():
220
+ with gr.TabItem("PDB Structural Embeddings"):
221
+ with gr.Row():
222
+ with gr.Box():
223
+ inp = gr.Textbox(
224
+ placeholder="PDB Code or upload file below", label="Input structure"
225
+ )
226
+ file = gr.File(file_count="single")
227
+ gr.Examples(["2CBA", "6VXX"], inp)
228
+ btn = gr.Button("View structure")
229
+ gr.Markdown("# PDB viewer using 3Dmol.js")
230
+ mol = gr.HTML()
231
+ emb = gr.Textbox(interactive=False)
232
+ btn.click(fn=update_st, inputs=[inp, file], outputs=[mol, emb])
233
+ with gr.TabItem("Nucleotide Sequence Embeddings"):
234
+ with gr.Box():
235
+ inp = gr.Textbox(
236
+ placeholder="ATCGCTGCCCGTAGATAATAAGAGACACTGAGGCC", label="Input Nucleotide Sequence"
237
+ )
238
+ btn = gr.Button("View embeddings")
239
+ emb = gr.Textbox(interactive=False)
240
+ btn.click(fn=update_nt, inputs=[inp], outputs=emb)
241
+ with gr.TabItem("Amino Acid Sequence Embeddings"):
242
+ with gr.Box():
243
+ inp = gr.Textbox(
244
+ placeholder="AAGQCYRGRCSGGLCCSKYGYCGSGPAYCG", label="Input Amino Acid Sequence"
245
+ )
246
+ btn = gr.Button("View embeddings")
247
+ emb = gr.Textbox(interactive=False)
248
+ btn.click(fn=update_aa, inputs=[inp], outputs=emb)
249
+ with gr.TabItem("Sentence Embeddings"):
250
+ with gr.Box():
251
+ inp = gr.Textbox(
252
+ placeholder="Your text here", label="Input Sentence"
253
  )
254
+ btn = gr.Button("View embeddings")
255
+ emb = gr.Textbox(interactive=False)
256
+ btn.click(fn=update_se, inputs=[inp], outputs=emb)
257
+ with gr.TabItem("MSA Embeddings"):
258
+ with gr.Box():
259
+ inp = gr.File(file_count="single", label="Input MSA")
260
+ btn = gr.Button("View embeddings")
261
+ emb = gr.Textbox(interactive=False)
262
+ btn.click(fn=update_msa, inputs=[inp], outputs=emb)
263
+ with gr.TabItem("GO Embeddings"):
264
+ with gr.Box():
265
+ inp = gr.Textbox(
266
+ placeholder="", label="Input GO Terms"
267
+ )
268
+ btn = gr.Button("View embeddings")
269
+ emb = gr.Textbox(interactive=False)
270
+ btn.click(fn=update_go, inputs=[inp], outputs=emb)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ download_data_if_required()
275
+ demo.launch()
gitattributes.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
msa.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import itertools
3
+ from pathlib import Path
4
+ from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
5
+ import string
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy.spatial.distance import squareform, pdist, cdist
10
+ from Bio import SeqIO
11
+ #import biotite.structure as bs
12
+ #from biotite.structure.io.pdbx import PDBxFile, get_structure
13
+ #from biotite.database import rcsb
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+
18
+ # This is an efficient way to delete lowercase characters and insertion characters from a string
19
+ deletekeys = dict.fromkeys(string.ascii_lowercase)
20
+ deletekeys["."] = None
21
+ deletekeys["*"] = None
22
+ translation = str.maketrans(deletekeys)
23
+
24
+
25
+ def read_sequence(filename: str) -> Tuple[str, str]:
26
+ """ Reads the first (reference) sequences from a fasta or MSA file."""
27
+ record = next(SeqIO.parse(filename, "fasta"))
28
+ return record.description, str(record.seq)
29
+
30
+ def remove_insertions(sequence: str) -> str:
31
+ """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
32
+ return sequence.translate(translation)
33
+
34
+ def read_msa(filename: str) -> List[Tuple[str, str]]:
35
+ """ Reads the sequences from an MSA file, automatically removes insertions."""
36
+ return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]
37
+
38
+
39
+ def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
40
+ """
41
+ Select sequences from the MSA to maximize the hamming distance
42
+ Alternatively, can use hhfilter
43
+ """
44
+ assert mode in ("max", "min")
45
+ if len(msa) <= num_seqs:
46
+ return msa
47
+
48
+ array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)
49
+
50
+ optfunc = np.argmax if mode == "max" else np.argmin
51
+ all_indices = np.arange(len(msa))
52
+ indices = [0]
53
+ pairwise_distances = np.zeros((0, len(msa)))
54
+ for _ in range(num_seqs - 1):
55
+ dist = cdist(array[indices[-1:]], array, "hamming")
56
+ pairwise_distances = np.concatenate([pairwise_distances, dist])
57
+ shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
58
+ shifted_index = optfunc(shifted_distance)
59
+ index = np.delete(all_indices, indices)[shifted_index]
60
+ indices.append(index)
61
+ indices = sorted(indices)
62
+ return [msa[idx] for idx in indices]
requirements.txt CHANGED
@@ -1,5 +1,13 @@
1
- transformers
2
  accelerate
3
- streamlit
4
- # bitsandbytes
5
- # scipy
 
 
 
 
 
 
 
 
 
 
 
1
  accelerate
2
+ gradio==3.33.1
3
+ --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html pyg-lib==0.2.0+pt20
4
+ requests==2.31.0
5
+ torch==2.0.1
6
+ --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-cluster==1.6.1
7
+ torch-geometric==2.3.1
8
+ torch-scatter==2.1.1
9
+ --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-sparse==0.6.17
10
+ --find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html torch-spline-conv==1.2.2
11
+ transformers==4.29.2
12
+ progres
13
+ fair-esm