import os
import gradio as gr
import numpy as np
import pandas as pd
import periodictable
import plotly.graph_objs as go
import polars as pl
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
ComputedStructureEntry,
GibbsComputedStructureEntry,
)
# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Names of the dataset configurations that get loaded below.
subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan"]
# polars_dfs = {
# subset: pl.read_parquet(
# "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
# storage_options={
# "token": HF_TOKEN,
# },
# )
# for subset in subsets
# }
# Keep only the train split of every configuration, restricted to the columns
# that the rest of this script reads.
subsets_ds = {}
for subset in subsets:
    dataset = load_dataset(
        path="LeMaterial/LeMat-Bulk",
        name=subset,
        token=HF_TOKEN,
        columns=[
            "lattice_vectors",
            "species_at_sites",
            "cartesian_site_positions",
            "energy",
            "energy_corrected",
            "immutable_id",
            "elements",
            "functional",
        ],
    )
    subsets_ds[subset] = dataset["train"]
# One pandas frame per subset, holding only the "elements" column.
elements_df = {
    name: ds.select_columns("elements").to_pandas()
    for name, ds in subsets_ds.items()
}
from scipy.sparse import csr_matrix
# Column index of every element symbol exposed by ``periodictable.elements``.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}

# Per subset: a sparse (rows x len(all_elements)) matrix with 1.0 wherever the
# dataset row contains the element of that column.
elements_indices = {}
for subset, df in elements_df.items():
    print("Processing subset: ", subset)
    # Collect only the coordinates of the non-zero cells. Filling a dense
    # (rows x elements) float array first and converting it afterwards needs
    # memory proportional to rows * elements for an almost empty matrix.
    row_ids = []
    col_ids = []
    for row_id, symbols in enumerate(df["elements"]):
        # A set keeps every cell at exactly 1: csr_matrix sums the values of
        # repeated coordinates, so a duplicated symbol would otherwise give 2.
        for col_id in {all_elements[symbol] for symbol in symbols}:
            row_ids.append(row_id)
            col_ids.append(col_id)
    elements_indices[subset] = csr_matrix(
        (
            np.ones(len(row_ids)),
            (
                np.asarray(row_ids, dtype=np.int64),
                np.asarray(col_ids, dtype=np.int64),
            ),
        ),
        shape=(len(df), len(all_elements)),
    )
# Functional label (as offered in the UI dropdown) -> key of the subset to query.
map_functional = {
    "PBE": "compatible_pbe",
    "PBESol (No correction scheme)": "compatible_pbesol",
    "SCAN (No correction scheme)": "compatible_scan",
}
def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a Plotly phase diagram for a chemical system.

    Args:
        elements: Dash-separated element symbols, e.g. ``"Li-Fe-O"``.
        energy_correction: Correction scheme label. Only
            ``"Database specific, or MP2020"`` and ``"The 110 PBE Method"``
            have an effect, and only when ``functional`` is ``"PBE"``.
        plot_style: ``"2D"`` for the default plot; any other value requests
            the 3D ternary style, which needs exactly three elements.
        functional: Key of ``map_functional`` selecting the subset to query.
        finite_temp: When truthy, the entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before the phase
            diagram is built.
        **kwargs: Accepted and ignored.

    Returns:
        A Plotly figure. When no diagram can be drawn (unknown element
        symbol, ``ValueError`` from ``PhaseDiagram``, 3D requested for a
        non-ternary system) the figure is empty apart from an annotation
        carrying the explanation.

    Raises:
        KeyError: If ``functional`` is not a key of ``map_functional``.
    """
    # Split elements and remove any whitespace
    element_list = [el.strip() for el in elements.split("-")]

    # The symbols come from a free-text field: report empty or misspelled
    # ones the same way as other failures instead of raising a bare KeyError.
    unknown_elements = [el for el in element_list if el not in all_elements]
    if unknown_elements:
        return go.Figure().add_annotation(
            text="Unknown element symbol(s): {}".format(
                ", ".join(repr(el) for el in unknown_elements)
            )
        )

    subset_name = map_functional[functional]

    element_list_vector = np.zeros(len(all_elements))
    for el in element_list:
        element_list_vector[all_elements[el]] = 1

    # A row belongs to the chemical system when all of its elements are among
    # the requested ones, i.e. its element count restricted to the requested
    # columns equals its total element count. At least one column is always
    # selected here because element_list is non-empty and fully validated.
    n_elements = elements_indices[subset_name].sum(axis=1)
    n_elements_query = elements_indices[subset_name][
        :, element_list_vector.astype(bool)
    ]
    indices_with_only_elements = np.where(
        n_elements_query.sum(axis=1) == n_elements
    )[0]
    print(indices_with_only_elements)

    entries_df = subsets_ds[subset_name].select(indices_with_only_elements).to_pandas()
    entries_df = entries_df[~entries_df["immutable_id"].isna()]
    print(entries_df)

    def get_energy_correction(energy_correction, row):
        """Return the correction to attach to the entry built from ``row``."""
        if energy_correction == "Database specific, or MP2020" and functional == "PBE":
            print("applying MP corrections")
            return (
                row["energy_corrected"] - row["energy"]
                if not np.isnan(row["energy_corrected"])
                else 0
            )
        if energy_correction == "The 110 PBE Method" and functional == "PBE":
            print("applying PBE110 corrections")
            return row["energy"] * 1.1 - row["energy"]
        # Every remaining combination is left uncorrected. The former guard
        # compared a subset name against "pbe", which is never one of the
        # values in map_functional, so it was always true; making the
        # fallback unconditional guarantees a numeric return value.
        print("not applying any corrections")
        return 0

    # Turn every selected dataset row into a pymatgen entry
    entries = [
        ComputedStructureEntry(
            Structure(
                [x.tolist() for x in row["lattice_vectors"].tolist()],
                row["species_at_sites"],
                row["cartesian_site_positions"],
                coords_are_cartesian=True,
            ),
            energy=row["energy"],
            correction=get_energy_correction(energy_correction, row),
            entry_id=row["immutable_id"],
            parameters={"run_type": row["functional"]},
        )
        for _, row in entries_df.iterrows()
    ]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram
    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        print(e)
        return go.Figure().add_annotation(text=str(e))

    # Generate plotly figure
    if plot_style == "2D":
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
        fig = plotter.get_plot()
    else:
        # For 3D plots, limit to ternary systems
        if len(element_list) == 3:
            plotter = PDPlotter(
                phase_diagram, show_unstable=True, backend="plotly", ternary_style="3d"
            )
            fig = plotter.get_plot()
        else:
            return go.Figure().add_annotation(
                text="3D plots are only available for ternary systems."
            )

    # Adjust the maximum energy above hull
    # (This is a placeholder as PDPlotter does not support direct filtering)

    # Return the figure
    return fig
# Gradio input widgets

# Free-text chemical system, pre-filled with an example.
elements_input = gr.Textbox(
    value="Li-Fe-O",
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
)

# max_e_above_hull_slider = gr.Slider(
#     minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )

# Correction scheme selector.
energy_correction_dropdown = gr.Dropdown(
    label="Energy correction",
    choices=[
        "The 110 PBE Method",
        "Database specific, or MP2020",
    ],
)

# Plot style selector.
plot_style_dropdown = gr.Dropdown(label="Plot Style", choices=["2D", "3D"])

# Functional selector.
functional_dropdown = gr.Dropdown(
    label="Functional",
    choices=["PBE", "PBESol (No correction scheme)", "SCAN (No correction scheme)"],
)

# Finite-temperature toggle.
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Caveat about the energy correction schemes. It is one parenthesized literal
# rather than a chain of "+=" statements. The line break before "Additionally"
# is written as an explicit "\n" escape, because a single-quoted string
# literal cannot span physical lines.
warning_message = (
    "⚠️ This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepencies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "\nAdditionally, we have provided the 110 PBE correction method"
    " from Rohr et al (2024)."
)

# Description text: the warning, a line break, then a one-line summary.
message = "{}\nGenerate a phase diagram for a set of elements using LeMat-Bulk data.".format(
    warning_message
)
message += """
This web app is powered by Crystal Toolkit, MP Dash Components, and Pymatgen. All tools are developed by the Materials Project. We are grateful for their open-source software packages. This app is intended for data exploration in LeMat-Bulk and is not affiliated with or endorsed by the Materials Project.
CC-BY-4.0 requires proper acknowledgement. If you use materials data with an immutable_id starting with mp-
, please cite the
Materials Project.
If you use materials data with an immutable_id starting with agm-
, please cite
Alexandria, PBE
or
Alexandria PBESol, SCAN.
If you use materials data with an immutable_id starting with oqmd-
, please cite
OQMD.
If you use the Phase Diagram or Crystal Viewer, please acknowledge Crystal Toolkit.