import torch
import warnings
from Bio import BiopythonWarning
from Bio.PDB import PDBIO
from Bio.PDB.StructureBuilder import StructureBuilder

from .constants import AA, restype_to_heavyatom_names


def save_pdb(data, path=None):
    """Build a Biopython structure from a featurized protein dict, optionally saving it as PDB.

    Residues are grouped into chains by the distinct values of ``data['chain_nb']``.
    Each chain is named after the first ``chain_id`` entry of its group. Residues
    whose ``aa`` code is rejected by ``AA.is_aa`` are skipped with a printed
    warning. Heavy atoms are skipped when their expected name is '' or when their
    mask entry is not set.

    Args:
        data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`,
            `pos_heavyatom`, `mask_heavyatom`.
            `chain_nb` must be a tensor (``.unique()`` and ``==`` are used on it).
            Every other entry is selected per chain and may be a tensor, a list or
            a str; any other value is passed through unchanged.
            NOTE(review): presumably `pos_heavyatom` is (L, A, 3) and
            `mask_heavyatom` is (L, A), with atom slots ordered as in
            `restype_to_heavyatom_names` — confirm against callers.
        path: Optional destination forwarded to ``PDBIO.save``. If None, nothing
            is written.

    Returns:
        The ``Bio.PDB`` structure that was built (structure id 0, a single model 0).
    """

    def _mask_select(v, mask):
        # Keep the per-residue entries of `v` where `mask` is true. Values of other
        # types are treated as chain-independent and returned unchanged.
        if isinstance(v, str) or isinstance(v, list):
            # Convert the mask once instead of indexing a tensor per element.
            keep = mask.tolist() if isinstance(mask, torch.Tensor) else list(mask)
            selected = [s for i, s in enumerate(v) if keep[i]]
            return ''.join(selected) if isinstance(v, str) else selected
        if isinstance(v, torch.Tensor):
            return v[mask]
        return v

    def _atom_fullname(atom_name):
        # Fixed-width (4-character) atom-name field. Shorter names are shifted one
        # column right and padded, e.g. ' N  ', ' CA ', ' CD1'. Four-character
        # names fill the field as-is.
        if len(atom_name) >= 4:
            return atom_name
        return (' ' + atom_name).ljust(4)

    def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch):
        # Append one chain (all residues of a single `chain_nb` group) to `builder`.
        builder.init_chain(chain_id_ch[0])
        builder.init_seg(' ')

        for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in zip(
            aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch
        ):
            if not AA.is_aa(aa_res.item()):
                print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_res.item()))
                continue
            restype = AA(aa_res.item())
            builder.init_residue(
                resname=str(restype),
                field=' ',
                resseq=resseq_res.item(),
                icode=icode_res,
            )

            for i, atom_name in enumerate(restype_to_heavyatom_names[restype]):
                if atom_name == '':
                    continue  # No expected atom in this slot for this residue type.
                # `not .all()` equals `(~m).any()` for bool masks, and also
                # treats 0/1 integer masks correctly (bitwise ~ would not).
                if not mask_allatom_res[i].all():
                    continue  # Atom is missing.
                # Biopython's Atom expects a NumPy coordinate array (needed e.g. for
                # `atom1 - atom2`). astype() copies, so the returned structure does
                # not alias the input tensor.
                coord = pos_allatom_res[i].detach().cpu().numpy().astype('f')
                builder.init_atom(atom_name, coord, 0.0, 1.0, ' ', _atom_fullname(atom_name))

    builder = StructureBuilder()
    builder.init_structure(0)
    builder.init_model(0)

    # Silence Biopython construction warnings only for the duration of this call,
    # leaving the caller's global warning filters untouched.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', BiopythonWarning)

        for ch_nb in data['chain_nb'].unique().tolist():
            mask = (data['chain_nb'] == ch_nb)
            _build_chain(
                builder,
                _mask_select(data['aa'], mask),
                _mask_select(data['pos_heavyatom'], mask),
                _mask_select(data['mask_heavyatom'], mask),
                _mask_select(data['chain_id'], mask),
                _mask_select(data['resseq'], mask),
                _mask_select(data['icode'], mask),
            )

        structure = builder.get_structure()
        if path is not None:
            io = PDBIO()
            io.set_structure(structure)
            io.save(path)

    return structure