File size: 1,285 Bytes
895a807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779b3f2
f1e36b5
779b3f2
f1e36b5
895a807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779b3f2
f1e36b5
779b3f2
f1e36b5
779b3f2
 
895a807
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple

import mols2grid
import pandas as pd
from rdkit import Chem
from terminator.selfies import decoder

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def draw_grid_generate(
    seeds: List[str],
    scaffolds: List[str],
    samples: List[str],
    n_cols: int = 5,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for seeds, scaffolds and generated molecules.

    Cells appear in order: all seeds, then all scaffolds, then all generated
    samples. Each cell is named ``Seed_<i>``, ``Scaffold_<i>`` or
    ``Generated_<i>``, where ``i`` is the index within its own input list.

    Args:
        seeds: SMILES strings of the seed molecules.
        scaffolds: SMILES strings of the scaffolds.
        samples: SMILES strings of the generated samples.
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).
        height: Height of the grid, forwarded to ``mols2grid.display``.
            Defaults to 1100.

    Returns:
        HTML to display (the ``data`` attribute of the mols2grid display object).
    """
    # Plain dict: the column set is fixed, so no defaultdict is needed.
    # Insertion order determines both DataFrame column order and tooltip order.
    columns = {
        # Unpacking accepts any sequence type (lists, tuples, ...).
        "SMILES": [*seeds, *scaffolds, *samples],
        "Name": (
            [f"Seed_{i}" for i in range(len(seeds))]
            + [f"Scaffold_{i}" for i in range(len(scaffolds))]
            + [f"Generated_{i}" for i in range(len(samples))]
        ),
    }

    result_df = pd.DataFrame(columns)
    obj = mols2grid.display(
        result_df,
        tooltip=list(columns),
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data