|
import glob |
|
import itertools |
|
from pathlib import Path |
|
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable |
|
import string |
|
|
|
import numpy as np |
|
import torch |
|
from scipy.spatial.distance import squareform, pdist, cdist |
|
from Bio import SeqIO |
|
|
|
|
|
|
|
from tqdm import tqdm |
|
import pandas as pd |
|
|
|
|
|
|
|
# Characters deleted from aligned sequences by ``remove_insertions``: every
# lowercase letter (insertion states) plus the "." and "*" symbols. Each key
# maps to None so that ``str.maketrans`` turns it into a deletion.
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}

translation = str.maketrans(deletekeys)
|
|
|
|
|
def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file.

    Args:
        filename: path to a FASTA-formatted file (an MSA stored in FASTA
            layout works too; only its first record is read).

    Returns:
        ``(description, sequence)`` of the first record. The sequence is
        returned as stored, without removing insertions.

    Raises:
        ValueError: if the file contains no FASTA records.
    """
    # Open the file explicitly so it is closed right after the first record is
    # read; handing the filename to SeqIO.parse and abandoning the iterator
    # after a single ``next`` can keep the handle open until garbage collection.
    with open(filename) as handle:
        record = next(SeqIO.parse(handle, "fasta"), None)
    if record is None:
        # A bare ``next`` would leak StopIteration to the caller instead.
        raise ValueError(f"no FASTA records found in {filename!r}")
    return record.description, str(record.seq)
|
|
|
def remove_insertions(sequence: str) -> str:
    """Strip insertion characters from a single aligned sequence.

    Every character that the module-level ``translation`` table maps to a
    deletion is dropped; all other characters are kept in order. This is
    required before aligned MSA rows can be compared column by column.
    """
    stripped = sequence.translate(translation)
    return stripped
|
|
|
def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Load every record of an MSA file with insertions stripped.

    Returns:
        ``(description, sequence)`` pairs in file order, where each sequence
        has been passed through ``remove_insertions``.
    """
    alignment: List[Tuple[str, str]] = []
    for entry in SeqIO.parse(filename, "fasta"):
        aligned_row = remove_insertions(str(entry.seq))
        alignment.append((entry.description, aligned_row))
    return alignment
|
|
|
|
|
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """
    Greedily select sequences from an aligned MSA by mean Hamming distance.

    The first (reference) sequence is always kept. Each further pick is the
    not-yet-selected sequence whose mean Hamming distance to the already
    selected set is largest (``mode="max"``, maximizing diversity) or smallest
    (``mode="min"``). Ties go to the earliest sequence in the MSA.
    Alternatively, hhfilter can be used for this kind of subsampling.

    Args:
        msa: ``(description, sequence)`` pairs. All sequences must have the
            same length (i.e. insertions already removed).
        num_seqs: number of sequences to return.
        mode: ``"max"`` or ``"min"``.

    Returns:
        The selected pairs in their original MSA order. If ``msa`` has at most
        ``num_seqs`` entries it is returned unchanged; if ``num_seqs`` is not
        positive an empty list is returned.

    Raises:
        ValueError: if ``mode`` is not ``"max"``/``"min"`` or the sequences
            differ in length.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    if num_seqs <= 0:
        return []
    if len(msa) <= num_seqs:
        return msa

    if len({len(seq) for _, seq in msa}) != 1:
        raise ValueError("all sequences in the MSA must have the same length")

    # One row per sequence, one uint8 character code per alignment column.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    selected = np.zeros(len(msa), dtype=bool)
    selected[0] = True
    indices = [0]
    # Running column-wise sum of the distances from every selected sequence to
    # each sequence. Accumulating row by row gives the same values as taking
    # the mean over a stacked (k, n) matrix, without rebuilding that matrix on
    # every iteration (O(k * n) overall instead of O(k^2 * n)).
    distance_sum = np.zeros(len(msa))

    for _ in range(num_seqs - 1):
        distance_sum += cdist(array[indices[-1:]], array, "hamming")[0]
        candidates = np.flatnonzero(~selected)  # ascending MSA indices
        mean_distance = distance_sum[candidates] / len(indices)
        index = candidates[optfunc(mean_distance)]
        selected[index] = True
        indices.append(index)

    return [msa[idx] for idx in sorted(indices)]