Spaces:
Runtime error
Runtime error
import torch | |
from Bio.PDB import Selection | |
from Bio.PDB.Residue import Residue | |
from easydict import EasyDict | |
from .constants import ( | |
AA, max_num_heavyatoms, | |
restype_to_heavyatom_names, | |
BBHeavyAtom | |
) | |
class ParsingException(Exception):
    """Raised when a structure cannot be turned into usable residue data."""
def _get_residue_heavyatom_info(res: Residue):
    """Collect heavy-atom coordinates of one residue into fixed-size tensors.

    Args:
        res: Biopython residue whose residue name is accepted by ``AA(...)``.

    Returns:
        A pair ``(coords, present)``:
        ``coords`` is a float tensor of shape ``(max_num_heavyatoms, 3)``,
        zero wherever no atom was found; ``present`` is a bool tensor of
        shape ``(max_num_heavyatoms,)`` that is True for the filled rows.
    """
    coords = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)
    present = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)
    # The row index of each atom is its position in the per-residue-type
    # name table, so the same row means the same atom name for a given type.
    atom_names = restype_to_heavyatom_names[AA(res.get_resname())]
    for slot, name in enumerate(atom_names):
        # Empty-string entries are skipped (presumably padding slots in the
        # name table); atoms missing from the residue leave their row zeroed.
        if name != '' and name in res:
            xyz = res[name].get_coord().tolist()
            coords[slot] = torch.tensor(xyz, dtype=coords.dtype)
            present[slot] = True
    return coords, present
def parse_biopython_structure(entity, unknown_threshold=1.0, max_resseq=None):
    """Flatten a Biopython entity into per-residue lists and tensors.

    Chains are visited in order of chain id, and residues within a chain in
    order of (resseq, icode). A residue is kept only if its name passes
    ``AA.is_aa``, it has 'CA', 'C' and 'N' atoms, and its type is not
    ``AA.UNK``.

    Args:
        entity: Biopython entity that ``Selection.unfold_entities`` can
            unfold to chains.
        unknown_threshold: Parsing fails when the fraction of ``AA.UNK``
            residues (among residues that passed the name/backbone filters)
            is greater than or equal to this value.
        max_resseq: If not None, residues whose resseq exceeds it are skipped.

    Returns:
        ``(data, seq_map)``. ``data`` is an EasyDict with keys ``chain_id``
        and ``icode`` (Python lists), ``resseq``, ``res_nb`` and ``aa``
        (built with ``torch.LongTensor``), and ``pos_heavyatom`` and
        ``mask_heavyatom`` (built with ``torch.stack``). ``seq_map`` maps
        ``(chain_id, resseq, icode)`` to the residue's index in those arrays.

    Raises:
        ParsingException: If no residue was kept, or if the unknown-residue
            fraction reaches ``unknown_threshold``.
    """
    ordered_chains = sorted(
        Selection.unfold_entities(entity, 'C'),
        key=lambda c: c.get_id(),
    )

    data = EasyDict({
        'chain_id': [],
        'resseq': [], 'icode': [], 'res_nb': [],
        'aa': [],
        'pos_heavyatom': [], 'mask_heavyatom': [],
    })
    # Applied to the collected lists just before returning; keys without an
    # entry here ('chain_id', 'icode') stay plain Python lists.
    converters = {
        'resseq': torch.LongTensor,
        'res_nb': torch.LongTensor,
        'aa': torch.LongTensor,
        'pos_heavyatom': torch.stack,
        'mask_heavyatom': torch.stack,
    }

    n_candidates = 0  # residues that passed the name and backbone filters
    n_unknown = 0     # ... of which the type is AA.UNK
    for chain in ordered_chains:
        # Sort residues by (resseq, icode).
        chain_residues = sorted(
            Selection.unfold_entities(chain, 'R'),
            key=lambda r: (r.get_id()[1], r.get_id()[2]),
        )
        res_nb = 0  # renumbered position; 0 means nothing kept in this chain yet
        for res in chain_residues:
            res_id = res.get_id()
            resseq = int(res_id[1])
            icode = res_id[2]
            if max_resseq is not None and resseq > max_resseq:
                continue

            resname = res.get_resname()
            if not AA.is_aa(resname):
                continue
            if not all(res.has_id(atom) for atom in ('CA', 'C', 'N')):
                continue

            restype = AA(resname)
            n_candidates += 1
            if restype == AA.UNK:
                n_unknown += 1
                continue

            # Chain info
            data.chain_id.append(chain.get_id())
            # Residue type; converted by torch.LongTensor at the end
            # (relies on AA members being int-like -- TODO confirm).
            data.aa.append(restype)
            # Heavy atoms
            coords, present = _get_residue_heavyatom_info(res)
            data.pos_heavyatom.append(coords)
            data.mask_heavyatom.append(present)

            # Sequential number: consecutive when the CA atoms of this and
            # the previously kept residue are within 4.0 (coordinate units,
            # presumably Angstrom); otherwise leave a gap of at least 2.
            if res_nb == 0:
                res_nb = 1
            else:
                # [-2] is the previously kept residue, since the current one
                # was appended just above.
                prev_ca = data.pos_heavyatom[-2][BBHeavyAtom.CA]
                this_ca = data.pos_heavyatom[-1][BBHeavyAtom.CA]
                ca_dist = torch.linalg.norm(prev_ca - this_ca, ord=2).item()
                # data.resseq[-1] is still the previous residue's resseq here.
                res_nb += 1 if ca_dist <= 4.0 else max(2, resseq - data.resseq[-1])

            data.resseq.append(resseq)
            data.icode.append(icode)
            data.res_nb.append(res_nb)

    if not data.aa:
        raise ParsingException('No parsed residues.')

    # n_candidates cannot be 0 here: with no candidates nothing would have
    # been kept and the check above would already have raised.
    if (n_unknown / n_candidates) >= unknown_threshold:
        raise ParsingException(
            f'Too many unknown residues, threshold {unknown_threshold:.2f}.'
        )

    seq_map = {
        residue_key: index
        for index, residue_key in enumerate(
            zip(data.chain_id, data.resseq, data.icode)
        )
    }

    for key, convert in converters.items():
        data[key] = convert(data[key])

    return data, seq_map