"""Interactive Dash viewer for crystal structures in LeMaterial/leMat-Bulk.

At start-up the train split of every LeMat-Bulk subset is loaded and a
composition index (one L2-normalised element-count vector per row) is built.
Users search by element list or formula; hits are ranked by cosine similarity
and the selected row's structure and properties are rendered.
"""

import os
import re

import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
import numpy as np
import pandas as pd
import periodictable
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from dash_breakpoints import WindowBreakpoints
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
from pymatgen.core import Structure

HF_TOKEN = os.environ.get("HF_TOKEN")
top_k = 500  # maximum number of search hits listed in the results table

subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan", "non_compatible"]

# Columns read from LeMat-Bulk (each listed once).
dataset_columns = [
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    # "energy_corrected",  # not yet available in LeMat-Bulk
    "immutable_id",
    "elements",
    "functional",
    "stress_tensor",
    "magnetic_moments",
    "forces",
    # "band_gap_direct",  # future release
    # "band_gap_indirect",  # future release
    "dos_ef",
    # "charges",  # future release
    "chemical_formula_reduced",
    "chemical_formula_descriptive",
    "total_magnetization",
]

# Load only the train split of each subset. Without ``split`` the call returns
# a DatasetDict, which ``concatenate_datasets`` refuses to concatenate.
datasets = [
    load_dataset(
        "LeMaterial/leMat-Bulk",
        subset,
        split="train",
        token=HF_TOKEN,
        columns=dataset_columns,
    )
    for subset in subsets
]

display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Results-table row position -> row index in ``dataset``; rebuilt per search.
mapping_table_idx_dataset_idx = {}

map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
n_elements = len(map_periodic_table)

# One element symbol plus an optional integer count, e.g. "Ge3" or "Cd".
# Shared by the index builder and the query parser so both agree.
formula_pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")

# --- Composition index -------------------------------------------------------
dataset = concatenate_datasets(datasets)
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()

extracted = train_df["chemical_formula_descriptive"].str.extractall(formula_pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(int)
wide_df = extracted.reset_index().pivot_table(  # move index to columns for pivoting
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)
all_elements = [el.symbol for el in periodictable.elements]  # full element list
# Reindex rows as well as columns: ``extractall`` drops rows without any match,
# which would otherwise misalign every later index row with ``dataset``.
wide_df = wide_df.reindex(
    index=range(len(train_df)), columns=all_elements, fill_value=0
)

composition = wide_df.to_numpy(dtype=float)
# Dividing a row by its atom count first does not change its direction, so a
# single L2 normalisation yields the same unit vectors.
row_norms = np.linalg.norm(composition, axis=1, keepdims=True)
row_norms[row_norms == 0] = 1.0  # keep unmatched rows all-zero instead of NaN
dataset_index = composition / row_norms
del train_df, extracted, wide_df, composition, row_norms

# --- Dash app ----------------------------------------------------------------
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the server for deployment


def _materials_input(compact=False):
    """Build the search widget; ``compact`` is the small-screen variant."""
    if compact:
        return dmp.MaterialsInput(
            allowedInputTypes=["elements", "formula"],
            hidePeriodicTable=True,
            periodicTableMode="none",
            hideWildcardButton=False,
            showSubmitButton=False,
            type="elements",
            id="materials-input",
        )
    return dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        id="materials-input",
    )


def _empty_panels():
    """Return the placeholder (structure, properties) panel contents."""
    return (
        html.Div(
            "Search a material to display its structure and properties",
            style={"textAlign": "center"},
        ),
        html.Div(
            "Properties will be displayed here",
            style={"textAlign": "center"},
        ),
    )


_structure_placeholder, _properties_placeholder = _empty_panels()

layout = html.Div(
    [
        WindowBreakpoints(
            id="breakpoints",
            widthBreakpointThresholdsPx=[800, 1200],
            widthBreakpointNames=["sm", "md", "lg"],
        ),
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        html.Div(
            [
                html.Div(
                    [_structure_placeholder],
                    id="structure-container",
                    style={
                        "width": "44%",
                        "verticalAlign": "top",
                        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
                        "borderRadius": "10px",
                        "backgroundColor": "#f9f9f9",
                        "padding": "20px",
                        "textAlign": "center",
                        "display": "flex",
                        "justifyContent": "center",
                        "alignItems": "center",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "55%",
                        "verticalAlign": "top",
                        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
                        "borderRadius": "10px",
                        "backgroundColor": "#f9f9f9",
                        "padding": "20px",
                        "overflow": "auto",
                        "maxHeight": "600px",
                        "display": "flex",
                        "justifyContent": "center",
                        "wordWrap": "break-word",
                    },
                    children=[_properties_placeholder],
                ),
            ],
            style={
                "marginTop": "20px",
                "display": "flex",
                "justifyContent": "space-between",
                # Ensure the two sections are responsive
                "flexWrap": "wrap",
            },
        ),
        html.Div(
            [
                html.Div(
                    [
                        html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
                        html.Div(
                            [
                                html.Div(
                                    [_materials_input()],
                                    id="materials-input-container",
                                    style={"width": "100%"},
                                ),
                            ],
                            style={
                                "display": "flex",
                                "justifyContent": "center",
                                "width": "100%",
                            },
                        ),
                    ],
                    style={"width": "48%", "verticalAlign": "top"},
                ),
                html.Div(
                    [
                        html.Label(
                            "Select a row to display the material's structure and properties",
                            style={"marginBottom": "20px"},
                        ),
                        dash.dash_table.DataTable(
                            id="table",
                            columns=[
                                (
                                    {"name": display_names[col], "id": col}
                                    if col != "energy"
                                    else {
                                        "name": display_names[col],
                                        "id": col,
                                        "type": "numeric",
                                        "format": {"specifier": ".2f"},
                                    }
                                )
                                for col in display_columns
                            ],
                            data=[{}],
                            style_cell={
                                "fontFamily": "Arial",
                                "padding": "10px",
                                "border": "1px solid #ddd",  # Subtle border for elegance
                                "textAlign": "left",
                                "fontSize": "14px",
                            },
                            style_header={
                                "backgroundColor": "#f5f5f5",  # Light grey header
                                "fontWeight": "bold",
                                "textAlign": "left",
                                "borderBottom": "2px solid #ddd",
                            },
                            style_data={
                                "backgroundColor": "#ffffff",
                                "color": "#333333",
                                "borderBottom": "1px solid #ddd",
                            },
                            style_data_conditional=[
                                {
                                    "if": {"state": "active"},
                                    "backgroundColor": "#e6f7ff",
                                    "border": "1px solid #1890ff",
                                },
                            ],
                            style_table={
                                "maxHeight": "400px",
                                "overflowX": "auto",
                                "overflowY": "auto",
                            },
                            style_as_list_view=True,
                            row_selectable="single",
                            selected_rows=[],
                        ),
                    ],
                    style={
                        "width": "48%",
                        "margin": "0 auto",
                        "padding": "20px",
                        "backgroundColor": "#ffffff",
                        "borderRadius": "10px",
                        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
                    },
                ),
            ],
            style={
                "marginTop": "20px",
                "marginBottom": "20px",
                "display": "flex",
                "flexDirection": "row",
                "alignItems": "center",
            },
        ),
    ],
    style={"marginLeft": "10px", "marginRight": "10px"},
)


def search_materials(query):
    """Return up to ``top_k`` dataset rows whose composition best matches *query*.

    *query* is either a comma-separated element list (``"Ac,Cd,Ge"``), where
    each listed element gets weight 1, or a formula (``"Ac2CdGe3"``), where
    each element is weighted by its (summed) count. Rows are ranked by cosine
    similarity between that vector and ``dataset_index``, best first.

    Side effect: ``mapping_table_idx_dataset_idx`` is rebuilt so that result
    position ``i`` maps to its row index in ``dataset``.

    Symbols not in ``map_periodic_table`` are ignored. If no recognised
    element remains, an empty list is returned (and the mapping is cleared).
    """
    query_vector = np.zeros(n_elements)
    if "," in query:
        for symbol in (part.strip() for part in query.split(",")):
            position = map_periodic_table.get(symbol)
            if position is not None:
                query_vector[position] = 1
    else:  # Formula
        for symbol, count in formula_pattern.findall(query):
            position = map_periodic_table.get(symbol)
            if position is not None:
                # Accumulate so a repeated symbol adds up, like the index does.
                query_vector[position] += int(count) if count else 1

    mapping_table_idx_dataset_idx.clear()
    query_norm = np.linalg.norm(query_vector)
    n_rows = dataset_index.shape[0]
    if query_norm == 0 or n_rows == 0:
        return []

    similarity = dataset_index @ query_vector / query_norm
    k = min(top_k, n_rows)
    # argpartition selects the k best in linear time; only those are sorted.
    best = np.argpartition(-similarity, k - 1)[:k]
    indices = best[np.argsort(-similarity[best])]

    mapping_table_idx_dataset_idx.update(
        {position: int(row_idx) for position, row_idx in enumerate(indices)}
    )
    return [dataset[int(row_idx)] for row_idx in indices]


# Callback to update the table based on search
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    Input("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Fill the results table with the display columns of each search hit.

    Returns an empty table until the submit button has been clicked at least
    once and a non-empty query is present.
    """
    # NOTE(review): the compact (small-screen) input hides the submit button,
    # so ``submitButtonClicks`` presumably stays None there and no search runs
    # on small screens — confirm this is intended.
    if n_clicks is None or not query:
        return []
    entries = search_materials(query)
    return [{col: entry[col] for col in display_columns} for entry in entries]


def _round_or_none(value, decimals=3):
    """Round a number or array-like for display; pass ``None`` through."""
    return None if value is None else np.round(value, decimals)


# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("table", "active_cell"),
    Input("table", "derived_virtual_selected_rows"),
)
def display_material(active_cell, selected_rows):
    """Render the structure viewer and property table for the chosen row.

    A selected row (radio button) takes precedence over the active cell.
    Placeholders are returned when nothing is chosen or the chosen row does
    not correspond to a search result (e.g. the initial empty row).
    """
    if selected_rows:  # may be None before the table is first rendered
        idx_active = selected_rows[0]
    elif active_cell:
        idx_active = active_cell["row"]
    else:
        idx_active = None

    dataset_idx = mapping_table_idx_dataset_idx.get(idx_active)
    if dataset_idx is None:
        return _empty_panels()

    row = dataset[dataset_idx]
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    if row["magnetic_moments"]:
        structure.add_site_property("magmom", row["magnetic_moments"])

    sga = SpacegroupAnalyzer(structure)
    structure_component = ctc.StructureMoleculeComponent(structure)

    energy = row["energy"]
    n_sites = len(row["species_at_sites"])
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": (
            round(energy / n_sites, 3) if energy is not None else None
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"],  # future release
        "Total Magnetization (μB)": row["total_magnetization"],
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": row["dos_ef"],
        "Crystal system": sga.get_crystal_system(),
        # Same spglib "international" symbol, without depending on whether
        # the symmetry dataset is a dict or an object in this pymatgen version.
        "International Spacegroup": sga.get_space_group_symbol(),
        "Magnetic moments (μB/f.u.)": _round_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": row["stress_tensor"],
        "Forces on atoms (eV/A)": _round_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3),  # future release
        "DFT Functional": row["functional"],
    }

    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr(
                        [
                            html.Th(
                                key,
                                style={"padding": "10px", "verticalAlign": "middle"},
                            ),
                            html.Td(
                                str(value),
                                style={
                                    "padding": "10px",
                                    "borderBottom": "1px solid #ddd",
                                },
                            ),
                        ],
                    )
                    for key, value in properties.items()
                ],
            )
        ],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )

    return structure_component.layout(), properties_html


@app.callback(
    Output("materials-input-container", "children"),
    Input("breakpoints", "widthBreakpoint"),
    State("breakpoints", "width"),
)
def update_materials_input_layout(breakpoint_name, width):
    """Swap in the compact search widget on "sm" screens.

    "md", "lg" and an undetected breakpoint all get the full widget, so the
    search box is never removed from the page.
    """
    if breakpoint_name == "sm":
        return _materials_input(compact=True)
    return _materials_input()


# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # ``run_server`` is a deprecated alias of ``run`` (removed in Dash 3).
    app.run(debug=True, port=7860, host="0.0.0.0")