|
import logging
from collections import defaultdict
from typing import List, Optional, Tuple

import mols2grid
import pandas as pd
|
|
|
# Module logger. Attaching a NullHandler means that, when the application has
# not configured logging, records from this module are discarded instead of
# falling back to Python's last-resort stderr handler.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
|
|
|
|
|
def draw_grid_generate(
    samples: List[str],
    seeds: Optional[List[str]] = None,
    n_cols: int = 3,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw an HTML grid for the generated molecules.

    Seed molecules (if any) are placed first, followed by the generated
    samples. Each entry is labelled ``Seed_<i>`` or ``Generated_<i>``
    (the index restarts at 0 for each group). Both the SMILES string and
    the label are shown in the tooltip.

    Args:
        samples: The generated samples as SMILES strings.
        seeds: Optional seed SMILES strings shown before the samples.
            Defaults to no seeds.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of each molecule drawing, forwarded unchanged to
            mols2grid's ``size`` argument. Defaults to (140, 200).
        height: Height of the rendered grid, forwarded unchanged to
            mols2grid's ``height`` argument. Defaults to 1100.

    Returns:
        HTML to display: the ``data`` attribute of the object returned by
        ``mols2grid.display``.
    """
    # A None sentinel replaces the former ``[]`` default so that no list
    # object is shared between calls.
    if seeds is None:
        seeds = []

    result_df = pd.DataFrame(
        {
            "SMILES": seeds + samples,
            "Name": [f"Seed_{i}" for i in range(len(seeds))]
            + [f"Generated_{i}" for i in range(len(samples))],
        }
    )

    obj = mols2grid.display(
        result_df,
        # Show every column ("SMILES", "Name") in the hover tooltip.
        tooltip=list(result_df.columns),
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data
|
|