# torchdrug / utils.py
# Duplicated from the Hugging Face Space jannisborn/gt4sd-moler (commit 7d76d6f).
# (Page residue from the original scrape: "jannisborn's picture", "raw",
#  "history blame", "No virus", "1.18 kB".)
import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple
import mols2grid
import pandas as pd
from rdkit import Chem
from terminator.selfies import decoder
# Module-level logger. A NullHandler is attached so that importing this module
# does not by itself configure any log output; the application decides that.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def draw_grid_generate(
    seeds: List[str],
    samples: List[str],
    n_cols: int = 3,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the seed and generated molecules.

    Seeds are listed first (named ``Seed_<i>``), followed by the generated
    samples (named ``Generated_<i>``). Both the SMILES and the name are shown
    in the tooltip.

    Args:
        seeds: SMILES strings of the seed molecules.
        samples: SMILES strings of the generated samples.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of each molecule in the grid, passed to mols2grid.
            Defaults to (140, 200).
        height: Height of the rendered grid, passed to ``mols2grid.display``.
            Defaults to 1100.

    Returns:
        HTML to display: the ``.data`` attribute of the object returned by
        ``mols2grid.display`` (presumably the raw HTML string — confirm
        against the installed mols2grid version).
    """
    # Normalize to lists so tuples or other sequences concatenate correctly.
    seeds = list(seeds)
    samples = list(samples)

    result = {
        "SMILES": seeds + samples,
        "Name": [f"Seed_{i}" for i in range(len(seeds))]
        + [f"Generated_{i}" for i in range(len(samples))],
    }
    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result),  # show every column ("SMILES", "Name") on hover
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data