luost26's picture
Update
753e275
raw
history blame
3.77 kB
import torch
from Bio.PDB import Selection
from Bio.PDB.Residue import Residue
from easydict import EasyDict
from .constants import (
AA, max_num_heavyatoms,
restype_to_heavyatom_names,
BBHeavyAtom
)
class ParsingException(Exception):
pass
def _get_residue_heavyatom_info(res: Residue):
pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)
mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)
restype = AA(res.get_resname())
for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]):
if atom_name == '': continue
if atom_name in res:
pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype)
mask_heavyatom[idx] = True
return pos_heavyatom, mask_heavyatom
def parse_biopython_structure(entity, unknown_threshold=1.0, max_resseq=None):
chains = Selection.unfold_entities(entity, 'C')
chains.sort(key=lambda c: c.get_id())
data = EasyDict({
'chain_id': [],
'resseq': [], 'icode': [], 'res_nb': [],
'aa': [],
'pos_heavyatom': [], 'mask_heavyatom': [],
})
tensor_types = {
'resseq': torch.LongTensor,
'res_nb': torch.LongTensor,
'aa': torch.LongTensor,
'pos_heavyatom': torch.stack,
'mask_heavyatom': torch.stack,
}
count_aa, count_unk = 0, 0
for i, chain in enumerate(chains):
seq_this = 0 # Renumbering residues
residues = Selection.unfold_entities(chain, 'R')
residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2])) # Sort residues by resseq-icode
for _, res in enumerate(residues):
resseq_this = int(res.get_id()[1])
if max_resseq is not None and resseq_this > max_resseq:
continue
resname = res.get_resname()
if not AA.is_aa(resname): continue
if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue
restype = AA(resname)
count_aa += 1
if restype == AA.UNK:
count_unk += 1
continue
# Chain info
data.chain_id.append(chain.get_id())
# Residue types
data.aa.append(restype) # Will be automatically cast to torch.long
# Heavy atoms
pos_heavyatom, mask_heavyatom = _get_residue_heavyatom_info(res)
data.pos_heavyatom.append(pos_heavyatom)
data.mask_heavyatom.append(mask_heavyatom)
# Sequential number
resseq_this = int(res.get_id()[1])
icode_this = res.get_id()[2]
if seq_this == 0:
seq_this = 1
else:
d_CA_CA = torch.linalg.norm(data.pos_heavyatom[-2][BBHeavyAtom.CA] - data.pos_heavyatom[-1][BBHeavyAtom.CA], ord=2).item()
if d_CA_CA <= 4.0:
seq_this += 1
else:
d_resseq = resseq_this - data.resseq[-1]
seq_this += max(2, d_resseq)
data.resseq.append(resseq_this)
data.icode.append(icode_this)
data.res_nb.append(seq_this)
if len(data.aa) == 0:
raise ParsingException('No parsed residues.')
if (count_unk / count_aa) >= unknown_threshold:
raise ParsingException(
f'Too many unknown residues, threshold {unknown_threshold:.2f}.'
)
seq_map = {}
for i, (chain_id, resseq, icode) in enumerate(zip(data.chain_id, data.resseq, data.icode)):
seq_map[(chain_id, resseq, icode)] = i
for key, convert_fn in tensor_types.items():
data[key] = convert_fn(data[key])
return data, seq_map