# materials_explorer / data_utils.py
# Author: Ramlaoui
# Commit message: "Use sparse matrices"
# Commit: a10ccb7
import os
import re
import crystal_toolkit.components as ctc
import numpy as np
import periodictable
from dash import dcc, html
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Maximum number of hits a composition search returns.
top_k = 500
def get_dataset():
    """Load and concatenate the ``train`` split of every LeMat-Bulk subset.

    Each subset of ``LeMaterial/leMat-Bulk`` is loaded with the column
    selection below (``HF_TOKEN`` is forwarded as the access token). Its
    ``"train"`` split is taken, and the splits are concatenated in the order
    of ``subset_names``.

    Returns:
        The concatenated ``datasets.Dataset``.
    """
    requested_columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        # "energy_corrected", # not yet available in LeMat-Bulk
        "immutable_id",
        "elements",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        # "band_gap_direct", #future release
        # "band_gap_indirect", #future release
        "dos_ef",
        # "charges", #future release
        "functional",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]
    subset_names = (
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    )
    # A fresh copy of the column list is handed to every call, as before.
    train_splits = [
        load_dataset(
            "LeMaterial/leMat-Bulk",
            subset_name,
            token=HF_TOKEN,
            columns=list(requested_columns),
        )["train"]
        for subset_name in subset_names
    ]
    return concatenate_datasets(train_splits)
# Human-readable label for each dataset column, in display order.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
# Dataset columns to display: the keys of display_names, in the same order.
display_columns = list(display_names)

# Global shared variables
# Rank-in-results -> dataset row index. Presumably this is the dict handed to
# search_materials, which clears and refills whatever mapping it receives.
mapping_table_idx_dataset_idx = dict()
def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build (or load from cache) a sparse element-composition index.

    Each row of the index corresponds to one dataset entry, restricted to
    ``index_range`` when given. Each column corresponds to one element of
    ``periodictable.elements``. A row holds the per-element site counts of
    that entry scaled to unit L2 norm, so its dot product with a query
    composition vector is proportional to their cosine similarity.

    Args:
        dataset: dataset exposing ``select``/``select_columns(...).to_pandas()``
            with at least the ``species_at_sites``, ``immutable_id`` and
            ``functional`` columns.
        index_range: optional indices passed to ``dataset.select`` so that only
            a subset is indexed. Ignored when the index is loaded from cache.
        cache_path: optional directory. If ``train_df.pkl`` exists there, the
            dataframe and ``dataset_index.npz`` are loaded from it. Otherwise
            both are computed and written there.
        empty_data: if True, skip all work and return a 1x1 zero array and an
            empty mapping.

    Returns:
        ``(dataset_index, immutable_id_to_idx)``.
        ``dataset_index`` is a ``scipy.sparse.csr_matrix``, or a dense 1x1
        ``np.ndarray`` when ``empty_data`` is set.
        ``immutable_id_to_idx`` maps each immutable id to its row number.
    """
    import pickle
    from collections import Counter

    from scipy.sparse import csr_matrix, load_npz, save_npz

    print("Building formula index")
    if empty_data:
        return np.zeros((1, 1)), {}

    df_cache = f"{cache_path}/train_df.pkl"
    index_cache = f"{cache_path}/dataset_index.npz"
    if cache_path is not None and os.path.exists(df_cache):
        # NOTE(security): pickle.load can execute arbitrary code; cache_path
        # must only ever contain files written by this function.
        with open(df_cache, "rb") as fh:
            train_df = pickle.load(fh)
        dataset_index = load_npz(index_cache)
    else:
        import tqdm

        use_dataset = dataset if index_range is None else dataset.select(index_range)
        train_df = use_dataset.select_columns(
            ["species_at_sites", "immutable_id", "functional"]
        ).to_pandas()

        all_elements = {
            str(el.symbol): i for i, el in enumerate(periodictable.elements)
        }  # full element list -> column index
        n_rows, n_cols = len(train_df), len(all_elements)

        # Collect (row, column, count) triplets directly instead of filling a
        # dense n_rows x n_elements array. That array is almost all zeros and
        # can be very large for the full dataset.
        rows, cols, counts = [], [], []
        for row_idx, species in tqdm.tqdm(
            enumerate(train_df["species_at_sites"].values), total=n_rows
        ):
            for symbol, count in Counter(species).items():
                rows.append(row_idx)
                cols.append(all_elements[symbol])
                counts.append(count)
        rows = np.asarray(rows, dtype=np.int64)
        cols = np.asarray(cols, dtype=np.int64)
        counts = np.asarray(counts, dtype=np.float64)

        # Scale every row to unit L2 norm (dividing by the atom count first,
        # as before, cancels out under L2 normalisation). Rows with no species
        # have zero norm; dividing them by 1 keeps them all-zero instead of
        # turning them into NaN, which would corrupt every search result.
        row_norms = np.sqrt(np.bincount(rows, weights=counts**2, minlength=n_rows))
        row_norms[row_norms == 0] = 1.0
        dataset_index = csr_matrix(
            (counts / row_norms[rows], (rows, cols)), shape=(n_rows, n_cols)
        )

        if cache_path is not None:
            with open(df_cache, "wb") as fh:
                pickle.dump(train_df, fh)
            save_npz(index_cache, dataset_index)

    # Invert {row: immutable_id} into {immutable_id: row}.
    immutable_id_to_idx = {
        immutable_id: row
        for row, immutable_id in train_df["immutable_id"].to_dict().items()
    }
    del train_df
    return dataset_index, immutable_id_to_idx
import pickle
from pathlib import Path
# TODO: Just load the index from a file
def build_embeddings_index(empty_data=False):
    """Build a FAISS index over the precomputed per-material feature vectors.

    Reads ``features_dict.pkl`` from the current working directory. It
    presumably maps immutable ids to feature arrays (TODO: confirm against the
    script that writes it). Each value is flattened into one row and added to
    ``FAISSIndex().index`` in dict iteration order.

    Args:
        empty_data: if True, skip loading and return ``(None, {}, {})``.

    Returns:
        ``(index, features_dict, idx_to_immutable_id)``.
        ``idx_to_immutable_id`` maps the insertion position of each vector
        back to its key in ``features_dict``.
    """
    if empty_data:
        return None, {}, {}

    # NOTE(security): pickle.load can execute arbitrary code; only load a
    # trusted, locally generated file here.
    with open("features_dict.pkl", "rb") as fh:
        features_dict = pickle.load(fh)

    from indexer import FAISSIndex

    # TODO: load a prebuilt index instead, e.g. FAISSIndex.from_store("index.faiss").
    index = FAISSIndex()
    for features in features_dict.values():
        index.index.add(features.reshape(1, -1))
    idx_to_immutable_id = dict(enumerate(features_dict))
    return index, features_dict, idx_to_immutable_id
def search_materials(
    query,
    dataset,
    dataset_index,
    mapping_table_idx_dataset_idx,
    map_periodic_table,
    limit=None,
):
    """Return the dataset rows whose composition best matches *query*.

    Two query syntaxes are accepted:

    * a comma-separated element list (``"Fe, O"``), where every listed element
      gets weight 1 and empty tokens (e.g. from a trailing comma) are ignored;
    * a chemical formula (``"Fe2O3"``, ``"HCOOH"``), where every element gets
      its total count, summed over repeated occurrences.

    Rows are ranked by the dot product of ``dataset_index`` with the query
    vector, highest first.

    Args:
        query: user query string.
        dataset: indexable collection; ``dataset[i]`` is returned for each hit.
        dataset_index: dense or scipy-sparse matrix with one row per dataset
            entry and one column per element.
        mapping_table_idx_dataset_idx: dict updated in place. It is cleared,
            then filled with ``{rank: dataset_row}`` for the returned hits.
        map_periodic_table: element symbol -> column index of ``dataset_index``.
        limit: maximum number of hits; defaults to the module-level ``top_k``.

    Returns:
        List of up to ``limit`` dataset rows, best match first.

    Raises:
        KeyError: (for a dict mapping) if the query names an element missing
            from ``map_periodic_table``.
    """
    if limit is None:
        limit = top_k

    query_vector = np.zeros(len(map_periodic_table))
    if "," in query:
        for symbol in (token.strip() for token in query.split(",")):
            if symbol:  # skip empty tokens such as the one after a trailing comma
                query_vector[map_periodic_table[symbol]] = 1
    else:
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            # Accumulate, so a formula that repeats an element ("HCOOH") counts
            # every occurrence instead of keeping only the last one.
            query_vector[map_periodic_table[symbol]] += int(count) if count else 1

    # Dividing by the (constant, non-negative) query norm cannot change the
    # ranking, and skipping it avoids a 0/0 NaN when no element was recognised.
    scores = np.asarray(dataset_index.dot(query_vector)).ravel()
    indices = np.argsort(scores)[::-1][:limit]

    options = [dataset[int(i)] for i in indices]
    mapping_table_idx_dataset_idx.clear()
    mapping_table_idx_dataset_idx.update(
        {rank: int(row) for rank, row in enumerate(indices)}
    )
    return options
def _round_or_none(value, decimals=3):
    """Return ``np.round(value, decimals)``, or ``None`` when *value* is None."""
    return None if value is None else np.round(value, decimals)


def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render a material's key properties as a two-column HTML table.

    Args:
        row: one dataset record (mapping) providing the LeMat-Bulk columns read
            below.
        structure: the material's structure; only ``structure.density`` is read.
        sga: a ``SpacegroupAnalyzer`` for ``structure``, used for the crystal
            system and the international space-group symbol.
        properties_container_update: mutable sequence with at least two slots.
            The properties dict is stored in slot 0 when ``container_type`` is
            ``"query"`` and in slot 1 otherwise.
        container_type: ``"query"``, or any other value for the second slot.

    Returns:
        A ``dash.html.Table`` with one header/value row per property.
    """
    energy = row["energy"]
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        # Nullable numeric fields show None instead of raising, like the
        # magnetization and Fermi-level entries below.
        "Energy per atom (eV/atom)": (
            round(energy / len(row["species_at_sites"]), 3)
            if energy is not None
            else None
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
        "Total Magnetization (μB)": (
            round(row["total_magnetization"], 3)
            if row["total_magnetization"] is not None
            else None
        ),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": (
            round(row["dos_ef"], 3) if row["dos_ef"] is not None else None
        ),
        "Crystal system": sga.get_crystal_system(),
        # Same symbol as sga.get_symmetry_dataset().international, but does not
        # depend on whether the installed pymatgen/spglib returns the symmetry
        # dataset as a dict or as an object.
        "International Spacegroup": sga.get_space_group_symbol(),
        "Magnetic moments (μB)": _round_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": _round_or_none(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _round_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3), # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    properties_container_update[0 if container_type == "query" else 1] = properties
    # TODO: highlight values that are equal in both slots of
    # properties_container_update (an earlier draft used backgroundColor "#e6f7ff").

    header_style = {"padding": "10px", "verticalAlign": "middle"}
    cell_style = {"padding": "10px", "borderBottom": "1px solid #ddd"}
    table_rows = [
        html.Tr(
            [
                html.Th(key, style=header_style),
                html.Td(str(value), style=cell_style),
            ],
        )
        for key, value in properties.items()
    ]
    return html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )
def get_crystal_plot(structure):
    """Build a 3D structure view and a symmetry analyzer for *structure*.

    Returns:
        ``(layout, sga)``: the layout of a crystal-toolkit
        ``StructureMoleculeComponent`` for ``structure``, and the
        ``SpacegroupAnalyzer`` constructed from the same structure.
    """
    analyzer = SpacegroupAnalyzer(structure)
    viewer_layout = ctc.StructureMoleculeComponent(structure).layout()
    return viewer_layout, analyzer