# Provenance: uploaded by luost26, commit 753e275 ("Update"), 2.97 kB.
import torch
import warnings
from Bio import BiopythonWarning
from Bio.PDB import PDBIO
from Bio.PDB.StructureBuilder import StructureBuilder
from .constants import AA, restype_to_heavyatom_names
def save_pdb(data, path=None):
    """Build a Biopython structure from per-residue data and optionally write it as PDB.

    Args:
        data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`,
            `pos_heavyatom`, `mask_heavyatom`. Every entry is indexed per residue and
            is split into chains by the values of the `chain_nb` tensor. Each chain's
            id is taken from `chain_id` at the chain's first residue.
            NOTE(review): presumably `pos_heavyatom` is (L, A, 3) and `mask_heavyatom`
            is (L, A), with the atom axis ordered like `restype_to_heavyatom_names`
            -- TODO confirm against the caller.
        path: If not None, the structure is also written to this path with `PDBIO`.

    Returns:
        The structure object produced by `StructureBuilder.get_structure()`.
    """

    def _mask_select(v, mask):
        """Keep the entries of `v` where the boolean tensor `mask` is True.

        Strings and lists are filtered element-wise, tensors by boolean indexing;
        any other value is returned unchanged.
        """
        if isinstance(v, (str, list)):
            # Convert the mask once instead of indexing the tensor per element.
            keep = mask.tolist()
            selected = [s for i, s in enumerate(v) if keep[i]]
            return ''.join(selected) if isinstance(v, str) else selected
        if isinstance(v, torch.Tensor):
            return v[mask]
        return v

    def _pdb_atom_fullname(atom_name):
        """Pad an atom name to the 4-character PDB atom-name field.

        Names shorter than 4 characters get one leading space and are right-padded
        to width 4 ('N' -> ' N  ', 'CA' -> ' CA ', 'CD1' -> ' CD1'); names of 4 or
        more characters are returned unchanged.
        """
        if len(atom_name) >= 4:
            return atom_name
        return (' ' + atom_name).ljust(4)

    def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch):
        """Append one chain (segment id ' ') and its residues/atoms to `builder`."""
        builder.init_chain(chain_id_ch[0])
        builder.init_seg(' ')
        for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in \
                zip(aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch):
            aa_value = aa_res.item()
            if not AA.is_aa(aa_value):
                # Residues of unrecognized type are reported and left out of the structure.
                print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_value))
                continue
            restype = AA(aa_value)
            builder.init_residue(
                resname=str(restype),
                field=' ',
                resseq=resseq_res.item(),
                icode=icode_res,
            )
            for i, atom_name in enumerate(restype_to_heavyatom_names[restype]):
                if atom_name == '':
                    continue  # No expected atom in this slot for this residue type
                if not mask_allatom_res[i].all():
                    continue  # Atom is missing
                # NOTE(review): coordinates are passed as a Python list, not an array;
                # confirm downstream users of the returned structure don't need array ops.
                builder.init_atom(
                    atom_name,
                    pos_allatom_res[i].tolist(),
                    0.0,  # b_factor
                    1.0,  # occupancy
                    ' ',  # altloc
                    _pdb_atom_fullname(atom_name),
                )

    # Scope the suppression: a bare warnings.simplefilter() call would silence
    # BiopythonWarning for the whole process after the first save_pdb call.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', BiopythonWarning)

        builder = StructureBuilder()
        builder.init_structure(0)
        builder.init_model(0)

        for ch_nb in data['chain_nb'].unique().tolist():
            mask = (data['chain_nb'] == ch_nb)
            _build_chain(
                builder,
                _mask_select(data['aa'], mask),
                _mask_select(data['pos_heavyatom'], mask),
                _mask_select(data['mask_heavyatom'], mask),
                _mask_select(data['chain_id'], mask),
                _mask_select(data['resseq'], mask),
                _mask_select(data['icode'], mask),
            )

        structure = builder.get_structure()
        if path is not None:
            io = PDBIO()
            io.set_structure(structure)
            io.save(path)
    return structure