# geodiff / utils.py
# Duplicated from jannisborn/gt4sd-diffusers (jannisborn, revision 5da68a0).
import logging
from collections import defaultdict
from typing import List, Optional, Tuple

import mols2grid
import pandas as pd
# Module-level logger named after this module. A NullHandler is attached so
# that this module produces no log output of its own unless the application
# configures logging.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def draw_grid_generate(
    samples: List[str],
    seeds: Optional[List[str]] = None,
    n_cols: int = 3,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules.

    Seed molecules (if any) are listed first, followed by the generated
    samples. Every molecule's SMILES and name are shown as a tooltip.

    Args:
        samples: SMILES strings of the generated samples. They are named
            ``Generated_0``, ``Generated_1``, ... in the grid.
        seeds: SMILES strings of the seed molecules. They are named
            ``Seed_0``, ``Seed_1``, ... and placed before the samples.
            Defaults to no seeds.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of each molecule drawing, forwarded to
            ``mols2grid.display``. Defaults to (140, 200).
        height: Height of the rendered grid, forwarded to
            ``mols2grid.display``. Defaults to 1100.

    Returns:
        HTML to display. This is the ``data`` attribute of the object
        returned by ``mols2grid.display``.
    """
    # A ``None`` default avoids sharing one mutable list between calls.
    if seeds is None:
        seeds = []

    # Column order here also fixes the tooltip field order: SMILES, Name.
    result = {
        "SMILES": seeds + samples,
        "Name": [f"Seed_{i}" for i in range(len(seeds))]
        + [f"Generated_{i}" for i in range(len(samples))],
    }
    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result),
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data