"""Dash app: search the LeMaterial dataset by composition and view structures.

At import time the module loads the ``train`` split of
``LeMaterial/leDataset``, turns every ``chemical_formula_descriptive`` string
into a unit-length vector of per-element counts, and builds the Dash layout.
A search query is turned into the same kind of vector and ranked against the
dataset by dot product (cosine similarity); the ``top_k`` best rows are shown
in a table, and clicking a row renders the structure and a few properties.
"""

import os
import re

import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
import numpy as np
import periodictable
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html  # noqa: F401  (dcc kept: file-level import)
from dash.dependencies import Input, Output, State  # noqa: F401
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester  # noqa: F401

HF_TOKEN = os.environ.get("HF_TOKEN")
top_k = 500

# Load only the train split of the dataset
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Table row -> dataset row of the most recent ``search_materials`` call.
# NOTE: this dict is shared by every session of the app; the callbacks below
# therefore carry the dataset row in each table row's ``id`` instead and only
# fall back to this mapping when no ``row_id`` is available.
mapping_table_idx_dataset_idx = {}

map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
n_elements = len(map_periodic_table)

# Used to index the dataset formulas; the group names are required because
# the extracted frame is addressed through its "element" / "count" columns.
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")

# Used to tokenize a formula typed by the user.
_QUERY_TOKEN_PATTERN = re.compile(r"([A-Z][a-z]{0,2})(\d*)")


def _build_composition_index(formulas, element_symbols):
    """Return one unit-length element-count vector per formula.

    Args:
        formulas: pandas Series of formula strings, one per dataset row.
        element_symbols: column order of the resulting matrix.

    Returns:
        Float array of shape ``(len(formulas), len(element_symbols))``.
        Row ``i`` always describes ``formulas`` position ``i``; a formula
        that yields no known element becomes an all-zero row.
    """
    # Positional index so that "level_0" below is the dataset row number.
    formulas = formulas.reset_index(drop=True)
    extracted = formulas.str.extractall(pattern)
    extracted["count"] = extracted["count"].replace("", "1").astype(int)
    wide = extracted.reset_index().pivot_table(  # move index to columns for pivoting
        index="level_0",  # original row index
        columns="element",
        values="count",
        aggfunc="sum",
        fill_value=0,
    )
    # extractall() drops rows without any match; re-adding them keeps the
    # matrix rows aligned with the dataset rows.
    wide = wide.reindex(index=formulas.index, columns=element_symbols, fill_value=0)
    vectors = wide.to_numpy(dtype=float)
    # L2 normalization is scale-invariant, so a prior division by the row sum
    # is unnecessary. Zero rows are left as zeros instead of becoming NaN
    # (NaN would sort to the top of the descending ranking).
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return vectors / norms


train_df = dataset.to_pandas()
all_elements = [el.symbol for el in periodictable.elements]  # full element list
dataset_index = _build_composition_index(
    train_df["chemical_formula_descriptive"], all_elements
)

# Initialize the Dash app
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the server for deployment

# Define the app layout
layout = html.Div(
    [
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "margin-top": "20px"},
        ),
        html.Div(
            [
                html.Div(
                    id="structure-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"margin-top": "20px"},
        ),
        html.Div(
            [
                html.Div(
                    [
                        html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
                        html.Div(
                            [
                                html.Div(
                                    [
                                        dmp.MaterialsInput(
                                            allowedInputTypes=["elements", "formula"],
                                            hidePeriodicTable=False,
                                            periodicTableMode="toggle",
                                            hideWildcardButton=True,
                                            showSubmitButton=True,
                                            submitButtonText="Search",
                                            type="elements",
                                            id="materials-input",
                                        ),
                                    ],
                                    style={
                                        "width": "48%",
                                    },
                                ),
                            ],
                            style={
                                "display": "flex",
                                "justifyContent": "center",
                                "width": "100%",
                            },
                        ),
                    ],
                    style={
                        "width": "100%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"margin-top": "20px", "margin-bottom": "20px"},
        ),
        html.Div(
            [
                html.Label("Select Material to Display"),
                dash.dash_table.DataTable(
                    id="table",
                    columns=[
                        (
                            {"name": display_names[col], "id": col}
                            if col != "energy"
                            else {
                                "name": display_names[col],
                                "id": col,
                                "type": "numeric",
                                "format": {"specifier": ".2f"},
                            }
                        )
                        for col in display_columns
                    ],
                    data=[{}],
                    style_table={
                        "overflowX": "auto",
                        "height": "220px",
                        "overflowY": "auto",
                    },
                    style_header={"fontWeight": "bold", "backgroundColor": "lightgrey"},
                    style_cell={"textAlign": "center"},
                    style_as_list_view=True,
                ),
            ],
            style={"margin-top": "30px"},
        ),
    ],
    style={
        "margin-left": "10px",
        "margin-right": "10px",
    },
)


def _query_to_vector(query):
    """Convert a search string into an element-count vector.

    A query containing a comma is read as a list of element symbols (each
    weighted 1); anything else is read as a formula such as ``Ac2CdGe3``.
    Symbols missing from ``map_periodic_table`` (including the empty string
    left by a trailing comma) are ignored rather than raising ``KeyError``.
    """
    vector = np.zeros(n_elements)
    if "," in query:
        for symbol in (part.strip() for part in query.split(",")):
            position = map_periodic_table.get(symbol)
            if position is not None:
                vector[position] = 1
    else:
        for symbol, count in _QUERY_TOKEN_PATTERN.findall(query):
            position = map_periodic_table.get(symbol)
            if position is not None:
                # Accumulate so that a repeated element (e.g. "CH3COOH")
                # is counted the same way the dataset index counts it.
                vector[position] += int(count) if count else 1
    return vector


def _rank_materials(query):
    """Return dataset row numbers of the ``top_k`` best matches, best first.

    Returns an empty list when the query contains no recognizable element,
    which would otherwise cause a division by zero.
    """
    query_vector = _query_to_vector(query)
    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        return []
    similarity = np.dot(dataset_index, query_vector) / query_norm
    best_first = np.argsort(similarity)[::-1][:top_k]
    return [int(i) for i in best_first]


def search_materials(query):
    """Return the dataset rows most similar in composition to ``query``.

    Args:
        query: comma-separated element symbols (``"Ac,Cd,Ge"``) or a
            formula (``"Ac2CdGe3"``).

    Returns:
        List of at most ``top_k`` dataset rows, most similar first.

    Side effects:
        Rewrites ``mapping_table_idx_dataset_idx`` so that result position
        ``i`` maps to its dataset row number.
    """
    indices = _rank_materials(query)
    options = [dataset[i] for i in indices]
    mapping_table_idx_dataset_idx.clear()
    mapping_table_idx_dataset_idx.update(enumerate(indices))
    return options


# Callback to update the table based on search
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    Input("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Fill the results table for the submitted query.

    Each returned row carries the dataset row number under ``"id"``; the
    table reports it back as ``active_cell["row_id"]``, so the selection does
    not depend on state shared between sessions.
    """
    if n_clicks is None or not query:
        return []

    rows = []
    for dataset_idx in _rank_materials(query):
        entry = dataset[dataset_idx]
        table_row = {col: entry[col] for col in display_columns}
        table_row["id"] = dataset_idx
        rows.append(table_row)
    return rows


def _resolve_dataset_index(active_cell):
    """Return the dataset row number for a clicked cell, or ``None``."""
    dataset_idx = active_cell.get("row_id")
    if dataset_idx is None:
        # Rows without an "id" (e.g. produced via ``search_materials``).
        dataset_idx = mapping_table_idx_dataset_idx.get(active_cell.get("row"))
    try:
        dataset_idx = int(dataset_idx)
    except (TypeError, ValueError):
        return None
    if not 0 <= dataset_idx < len(dataset):
        return None
    return dataset_idx


# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("table", "active_cell"),
)
def display_material(active_cell):
    """Render the structure viewer and property table for the clicked row.

    Returns two empty strings when nothing is selected or the selection
    cannot be resolved to a dataset row.
    """
    if not active_cell:
        return "", ""

    dataset_idx = _resolve_dataset_index(active_cell)
    if dataset_idx is None:
        return "", ""
    row = dataset[dataset_idx]

    structure = Structure(
        # Flatten the nested lattice vectors into a single sequence.
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )

    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    n_sites = len(row["species_at_sites"])
    energy = row["energy"]
    energy_per_atom = energy / n_sites if energy is not None and n_sites else None

    # ``or`` would discard a legitimate direct gap of 0.0.
    band_gap = row["band_gap_direct"]
    if band_gap is None:
        band_gap = row["band_gap_indirect"]

    # Extract key properties
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        "Band Gap (eV)": band_gap,
        "Total Magnetization (μB/f.u.)": row["total_magnetization"],
    }

    # Format properties as an HTML table
    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr([html.Th(key), html.Td(str(value))])
                    for key, value in properties.items()
                ]
            )
        ],
        style={
            "border": "1px solid black",
            "width": "100%",
            "borderCollapse": "collapse",
        },
    )

    return structure_component.layout(), properties_html


# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # The server listens on all interfaces, so debug mode (interactive
    # debugger, auto-reload) is opt-in via DASH_DEBUG rather than always on.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
    }
    app.run_server(debug=debug_enabled, port=7860, host="0.0.0.0")