Spaces:
Runtime error
Runtime error
import torch | |
import warnings | |
from Bio import BiopythonWarning | |
from Bio.PDB import PDBIO | |
from Bio.PDB.StructureBuilder import StructureBuilder | |
from .constants import AA, restype_to_heavyatom_names | |
def save_pdb(data, path=None):
    """Assemble a Bio.PDB structure from per-residue tensors; optionally save it as PDB.

    Args:
        data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`,
            `pos_heavyatom`, `mask_heavyatom`. `chain_nb` must be a tensor (this
            function calls `.unique()` and `==` on it). The other entries are
            split per chain with the boolean mask `chain_nb == ch_nb`, so they are
            presumably indexed by residue along their first axis (TODO confirm
            against callers).
        path: If not None, the structure is also written to this path with `PDBIO`.

    Returns:
        The structure returned by `StructureBuilder.get_structure()`.
    """

    def _mask_select(v, mask):
        """Keep the entries of `v` where `mask` is true; other value types pass through."""
        if isinstance(v, str):
            return ''.join([s for i, s in enumerate(v) if mask[i]])
        elif isinstance(v, list):
            return [s for i, s in enumerate(v) if mask[i]]
        elif isinstance(v, torch.Tensor):
            return v[mask]
        else:
            return v

    def _atom_fullname(atom_name):
        """Pad an atom name to the 4-character fullname form, e.g. 'CA' -> ' CA '.

        Names of 1-3 characters get one leading space and are right-padded to
        width 4 (the original 1-char branch produced only 3 characters).
        Names of 4 or more characters are used verbatim.
        """
        if len(atom_name) >= 4:
            return atom_name
        return (' ' + atom_name).ljust(4)

    def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch,
                     chain_id_ch, resseq_ch, icode_ch):
        """Add one chain to `builder`, one residue per entry of the inputs.

        Residues with an unknown amino-acid type are skipped with a printed
        warning. Atoms are skipped when they are not expected for the residue
        type (empty name) or when their mask is false.
        """
        # Every residue passed here shares one chain_nb; the chain ID comes from the first.
        builder.init_chain(chain_id_ch[0])
        builder.init_seg(' ')

        residues = zip(aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch)
        for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in residues:
            aa_idx = aa_res.item()
            if not AA.is_aa(aa_idx):
                print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_idx))
                continue
            restype = AA(aa_idx)
            builder.init_residue(
                resname=str(restype),
                field=' ',
                resseq=resseq_res.item(),
                icode=icode_res,
            )

            for i, atom_name in enumerate(restype_to_heavyatom_names[restype]):
                if atom_name == '':
                    continue  # No atom expected in this slot for this residue type
                # `not m.all()` equals `(~m).any()` for bool masks, and unlike `~`
                # it is also correct for 0/1 integer masks.
                if not mask_allatom_res[i].all():
                    continue  # Atom is missing
                # Bio.PDB atoms expect a NumPy coordinate array (e.g. for `atom1 - atom2`).
                # `.double()` keeps exactly the values `.tolist()` would give, so the
                # written PDB text is unchanged.
                coord = pos_allatom_res[i].detach().cpu().double().numpy()
                builder.init_atom(atom_name, coord, 0.0, 1.0, ' ', _atom_fullname(atom_name))

    # Suppress BiopythonWarning only for the duration of this call instead of
    # permanently changing the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', BiopythonWarning)

        builder = StructureBuilder()
        builder.init_structure(0)
        builder.init_model(0)

        chain_nb = data['chain_nb']
        for ch_nb in chain_nb.unique().tolist():
            mask = (chain_nb == ch_nb)
            aa = _mask_select(data['aa'], mask)
            pos_heavyatom = _mask_select(data['pos_heavyatom'], mask)
            mask_heavyatom = _mask_select(data['mask_heavyatom'], mask)
            chain_id = _mask_select(data['chain_id'], mask)
            resseq = _mask_select(data['resseq'], mask)
            icode = _mask_select(data['icode'], mask)
            _build_chain(builder, aa, pos_heavyatom, mask_heavyatom, chain_id, resseq, icode)

        structure = builder.get_structure()
        if path is not None:
            writer = PDBIO()
            writer.set_structure(structure)
            writer.save(path)
    return structure