# SVM / msa.py
# Source: abondrn, commit 4d84fae ("Added msa.py utils")
import glob
import itertools
from pathlib import Path
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import string
import numpy as np
import torch
from scipy.spatial.distance import squareform, pdist, cdist
from Bio import SeqIO
#import biotite.structure as bs
#from biotite.structure.io.pdbx import PDBxFile, get_structure
#from biotite.database import rcsb
from tqdm import tqdm
import pandas as pd
# Translation table used by `remove_insertions`: every lowercase ASCII letter,
# plus the "." and "*" characters, maps to None and is therefore deleted by
# `str.translate` in a single C-level pass.
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)
def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file.

    Unlike `read_msa`, insertion characters are NOT removed; the sequence is
    returned exactly as stored in the file.

    Args:
        filename: Path to a FASTA-formatted file (e.g. an MSA in FASTA form).

    Returns:
        A ``(description, sequence)`` tuple for the first record.

    Raises:
        ValueError: If the file contains no FASTA records.
    """
    # Use a default with next() so an empty file does not leak a bare
    # StopIteration, which would silently terminate any enclosing generator.
    record = next(SeqIO.parse(filename, "fasta"), None)
    if record is None:
        raise ValueError(f"no FASTA records found in {filename!r}")
    return record.description, str(record.seq)
def remove_insertions(sequence: str) -> str:
    """Delete insertion characters from an aligned sequence.

    Every character mapped to None in the module-level `translation` table
    (lowercase ASCII letters, "." and "*") is stripped, which is needed to
    load aligned sequences from an MSA.
    """
    stripped = sequence.translate(translation)
    return stripped
def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Load every record of a FASTA-formatted MSA file.

    Each sequence is passed through `remove_insertions` before being stored.

    Returns:
        A list of ``(description, sequence)`` tuples in file order.
    """
    entries: List[Tuple[str, str]] = []
    for rec in SeqIO.parse(filename, "fasta"):
        entries.append((rec.description, remove_insertions(str(rec.seq))))
    return entries
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """
    Greedily select `num_seqs` sequences from an MSA by Hamming distance.

    The first sequence (index 0) is always kept. Each further pick is the
    not-yet-selected sequence whose average Hamming distance to all already
    selected sequences is largest (``mode="max"``, a diversity filter) or
    smallest (``mode="min"``). Ties go to the lowest index.
    Alternatively, can use hhfilter.

    Args:
        msa: ``(description, sequence)`` pairs; all sequences must have the
            same (aligned) length and be ASCII-encodable.
        num_seqs: Number of sequences to keep. Values <= 0 select nothing.
        mode: ``"max"`` or ``"min"``.

    Returns:
        The selected pairs in their original MSA order. If ``msa`` already has
        at most ``num_seqs`` entries, ``msa`` itself is returned unchanged.

    Raises:
        ValueError: If ``mode`` is invalid or the sequences differ in length.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    if len(msa) <= num_seqs:
        return msa
    if num_seqs <= 0:
        return []
    if len({len(seq) for _, seq in msa}) > 1:
        raise ValueError("all sequences in the MSA must have the same aligned length")

    # (num_sequences, alignment_length) matrix of byte codes, one per residue.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)
    optfunc = np.argmax if mode == "max" else np.argmin

    selected = np.zeros(len(msa), dtype=bool)
    selected[0] = True
    indices = [0]
    # Running total of mismatch counts to every selected sequence. Ranking by
    # this exact integer sum is equivalent to ranking by the mean Hamming
    # distance (both divide by constants), without re-averaging the whole
    # history each round or relying on float rounding to break ties.
    total_mismatches = np.zeros(len(msa), dtype=np.int64)
    for _ in range(num_seqs - 1):
        total_mismatches += (array != array[indices[-1]]).sum(axis=1)
        candidates = np.flatnonzero(~selected)
        index = int(candidates[optfunc(total_mismatches[candidates])])
        selected[index] = True
        indices.append(index)
    return [msa[idx] for idx in sorted(indices)]