Added msa.py utils
Browse files
app.py
CHANGED
@@ -64,12 +64,17 @@ def msa_embed(msa):
|
|
64 |
msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
|
65 |
msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
|
66 |
|
67 |
-
|
|
|
68 |
temp = temp[12][:,:,0,:]
|
69 |
temp = torch.mean(temp,(0,1))
|
70 |
return temp
|
71 |
|
72 |
|
|
|
|
|
|
|
|
|
73 |
def download_data_if_required():
|
74 |
url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
|
75 |
fps = [pg.trained_model_fp]
|
|
|
64 |
msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
|
65 |
msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
|
66 |
|
67 |
+
with torch.no_grad():
|
68 |
+
temp = msa_transformer(msa_transformer_batch_tokens,repr_layers=[12])['representations']
|
69 |
temp = temp[12][:,:,0,:]
|
70 |
temp = torch.mean(temp,(0,1))
|
71 |
return temp
|
72 |
|
73 |
|
74 |
+
def go_embed(terms):
    """Placeholder for embedding ``terms``; not implemented yet.

    The function currently ignores its argument and returns ``None``.
    """
    return None
|
76 |
+
|
77 |
+
|
78 |
def download_data_if_required():
|
79 |
url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
|
80 |
fps = [pg.trained_model_fp]
|
msa.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import itertools
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
|
5 |
+
import string
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from scipy.spatial.distance import squareform, pdist, cdist
|
10 |
+
from Bio import SeqIO
|
11 |
+
#import biotite.structure as bs
|
12 |
+
#from biotite.structure.io.pdbx import PDBxFile, get_structure
|
13 |
+
#from biotite.database import rcsb
|
14 |
+
from tqdm import tqdm
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
|
18 |
+
# Translation table used to strip insertions from aligned sequences:
# every lowercase ASCII letter plus "." and "*" maps to None, so
# str.translate drops all of them in a single pass over the string.
deletekeys = dict.fromkeys(string.ascii_lowercase + ".*")
translation = str.maketrans(deletekeys)
|
23 |
+
|
24 |
+
|
25 |
+
def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file.

    Args:
        filename: Path of the FASTA-formatted file to parse.

    Returns:
        A ``(description, sequence)`` tuple for the first record in the file.

    Raises:
        ValueError: If the parser yields no records for ``filename``.
    """
    # Passing a default keeps an empty file from leaking a bare
    # StopIteration out of this function, which is confusing for callers
    # and silently terminates any generator that happens to call it.
    record = next(SeqIO.parse(filename, "fasta"), None)
    if record is None:
        raise ValueError(f"No FASTA records found in {filename!r}")
    return record.description, str(record.seq)
|
29 |
+
|
30 |
+
def remove_insertions(sequence: str) -> str:
    """Strip insertion characters from one sequence of an alignment.

    Applies the module-level ``translation`` table to ``sequence`` and
    returns the resulting string; needed when loading aligned sequences
    from an MSA.
    """
    aligned = sequence.translate(translation)
    return aligned
|
33 |
+
|
34 |
+
def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Load all records of a FASTA-formatted MSA file, without insertions.

    Returns a list of ``(description, sequence)`` tuples in file order,
    where every sequence has been passed through ``remove_insertions``.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        description = record.description
        entries.append((description, remove_insertions(str(record.seq))))
    return entries
|
37 |
+
|
38 |
+
|
39 |
+
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """Greedily pick ``num_seqs`` entries from an MSA by Hamming distance.

    Selection starts from the first entry. Each step adds the not-yet-chosen
    sequence whose mean Hamming distance to the already chosen sequences is
    largest (``mode="max"``) or smallest (``mode="min"``). Alternatively,
    hhfilter can be used for this kind of subsampling.

    Args:
        msa: ``(description, sequence)`` tuples. Assumes all sequences have
            the same length and contain only ASCII characters, since they
            are packed into a 2-D byte array.
        num_seqs: Number of entries to keep.
        mode: ``"max"`` to maximise or ``"min"`` to minimise the mean
            distance to the sequences selected so far.

    Returns:
        The selected entries in their original MSA order. If the MSA has at
        most ``num_seqs`` entries, the input list itself is returned. If
        ``num_seqs`` is not positive, an empty list is returned.

    Raises:
        ValueError: If ``mode`` is neither ``"max"`` nor ``"min"``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    if len(msa) <= num_seqs:
        return msa
    # Without this guard the seed sequence would be returned even though
    # the caller asked for no sequences at all.
    if num_seqs <= 0:
        return []

    # One row per sequence, one uint8 column per alignment position.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    all_indices = np.arange(len(msa))
    indices = [0]  # the first sequence is always kept
    # Row i holds the distances from the i-th selected sequence to all others.
    pairwise_distances = np.zeros((0, len(msa)))
    for _ in range(num_seqs - 1):
        dist = cdist(array[indices[-1:]], array, "hamming")
        pairwise_distances = np.concatenate([pairwise_distances, dist])
        # Drop already selected columns, then average over selected rows.
        shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
        shifted_index = optfunc(shifted_distance)
        # Map the position among the remaining columns back to an MSA index.
        index = np.delete(all_indices, indices)[shifted_index]
        indices.append(int(index))
    return [msa[idx] for idx in sorted(indices)]
|