Spaces:
Running
Running
More responsive
Browse files- app.py +59 -339
- assets/styles.css +124 -0
- components.py +184 -0
- data_utils.py +261 -0
app.py
CHANGED
@@ -1,112 +1,73 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
|
4 |
import crystal_toolkit.components as ctc
|
5 |
import dash
|
6 |
import dash_mp_components as dmp
|
7 |
import numpy as np
|
8 |
-
import pandas as pd
|
9 |
import periodictable
|
10 |
from crystal_toolkit.settings import SETTINGS
|
11 |
from dash import dcc, html
|
12 |
from dash.dependencies import Input, Output, State
|
13 |
from dash_breakpoints import WindowBreakpoints
|
14 |
-
from datasets import concatenate_datasets, load_dataset
|
15 |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
|
16 |
from pymatgen.core import Structure
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
|
|
24 |
|
25 |
-
|
26 |
-
for subset in subsets:
|
27 |
-
dataset = load_dataset(
|
28 |
-
"LeMaterial/leMat-Bulk",
|
29 |
-
subset,
|
30 |
-
token=HF_TOKEN,
|
31 |
-
columns=[
|
32 |
-
"lattice_vectors",
|
33 |
-
"species_at_sites",
|
34 |
-
"cartesian_site_positions",
|
35 |
-
"energy",
|
36 |
-
# "energy_corrected", # not yet available in LeMat-Bulk
|
37 |
-
"immutable_id",
|
38 |
-
"elements",
|
39 |
-
"functional",
|
40 |
-
"stress_tensor",
|
41 |
-
"magnetic_moments",
|
42 |
-
"forces",
|
43 |
-
# "band_gap_direct", #future release
|
44 |
-
# "band_gap_indirect", #future release
|
45 |
-
"dos_ef",
|
46 |
-
# "charges", #future release
|
47 |
-
"functional",
|
48 |
-
"chemical_formula_reduced",
|
49 |
-
"chemical_formula_descriptive",
|
50 |
-
"total_magnetization",
|
51 |
-
"entalpic_fingerprint"
|
52 |
-
],
|
53 |
-
)
|
54 |
-
datasets.append(dataset["train"])
|
55 |
|
56 |
-
|
57 |
"chemical_formula_descriptive",
|
58 |
"functional",
|
59 |
"immutable_id",
|
60 |
"energy",
|
61 |
]
|
62 |
-
|
63 |
"chemical_formula_descriptive": "Formula",
|
64 |
"functional": "Functional",
|
65 |
"immutable_id": "Material ID",
|
66 |
"energy": "Energy (eV)",
|
67 |
}
|
68 |
|
|
|
69 |
mapping_table_idx_dataset_idx = {}
|
|
|
70 |
|
71 |
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
|
72 |
-
n_elements = len(map_periodic_table)
|
73 |
|
74 |
-
#
|
75 |
-
|
76 |
-
dataset =
|
77 |
-
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
|
78 |
-
|
79 |
-
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
|
80 |
-
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
|
81 |
-
extracted["count"] = extracted["count"].replace("", "1").astype(int)
|
82 |
-
|
83 |
-
wide_df = extracted.reset_index().pivot_table( # Move index to columns for pivoting
|
84 |
-
index="level_0", # original row index
|
85 |
-
columns="element",
|
86 |
-
values="count",
|
87 |
-
aggfunc="sum",
|
88 |
-
fill_value=0,
|
89 |
)
|
90 |
|
91 |
-
all_elements = [el.symbol for el in periodictable.elements] # full element list
|
92 |
-
wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
|
93 |
-
|
94 |
-
|
95 |
-
dataset_index = wide_df.values
|
96 |
-
|
97 |
-
dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
|
98 |
-
dataset_index = (
|
99 |
-
dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
|
100 |
-
) # Normalize vectors
|
101 |
-
|
102 |
-
del train_df, extracted, wide_df
|
103 |
-
|
104 |
# Initialize the Dash app
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
server = app.server # Expose the server for deployment
|
107 |
|
108 |
# Define the app layout
|
109 |
-
layout = html.Div(
|
110 |
[
|
111 |
WindowBreakpoints(
|
112 |
id="breakpoints",
|
@@ -119,178 +80,26 @@ layout = html.Div(
|
|
119 |
),
|
120 |
html.Div(
|
121 |
[
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
),
|
128 |
-
],
|
129 |
-
id="structure-container",
|
130 |
-
style={
|
131 |
-
"width": "44%",
|
132 |
-
"verticalAlign": "top",
|
133 |
-
"boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
|
134 |
-
"borderRadius": "10px",
|
135 |
-
"backgroundColor": "#f9f9f9",
|
136 |
-
"padding": "20px",
|
137 |
-
"textAlign": "center",
|
138 |
-
"display": "flex",
|
139 |
-
"justifyContent": "center",
|
140 |
-
"alignItems": "center",
|
141 |
-
},
|
142 |
-
),
|
143 |
-
html.Div(
|
144 |
-
id="properties-container",
|
145 |
-
style={
|
146 |
-
"width": "55%",
|
147 |
-
"paddingLeft": "4%",
|
148 |
-
"verticalAlign": "top",
|
149 |
-
"boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
|
150 |
-
"borderRadius": "10px",
|
151 |
-
"backgroundColor": "#f9f9f9",
|
152 |
-
"padding": "20px",
|
153 |
-
"overflow": "auto",
|
154 |
-
"maxHeight": "600px",
|
155 |
-
"display": "flex",
|
156 |
-
"justifyContent": "center",
|
157 |
-
"wordWrap": "break-word",
|
158 |
-
},
|
159 |
-
children=[
|
160 |
-
html.Div(
|
161 |
-
"Properties will be displayed here",
|
162 |
-
style={"textAlign": "center"},
|
163 |
-
),
|
164 |
-
],
|
165 |
-
),
|
166 |
],
|
167 |
-
|
168 |
-
"marginTop": "20px",
|
169 |
-
"display": "flex",
|
170 |
-
"justifyContent": "space-between", # Ensure the two sections are responsive
|
171 |
-
"flexWrap": "wrap",
|
172 |
-
},
|
173 |
),
|
174 |
html.Div(
|
175 |
[
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
[
|
183 |
-
dmp.MaterialsInput(
|
184 |
-
allowedInputTypes=["elements", "formula"],
|
185 |
-
hidePeriodicTable=False,
|
186 |
-
periodicTableMode="toggle",
|
187 |
-
hideWildcardButton=True,
|
188 |
-
showSubmitButton=True,
|
189 |
-
submitButtonText="Search",
|
190 |
-
type="elements",
|
191 |
-
id="materials-input",
|
192 |
-
),
|
193 |
-
],
|
194 |
-
id="materials-input-container",
|
195 |
-
style={
|
196 |
-
"width": "100%",
|
197 |
-
},
|
198 |
-
),
|
199 |
-
],
|
200 |
-
style={
|
201 |
-
"display": "flex",
|
202 |
-
"justifyContent": "center",
|
203 |
-
"width": "100%",
|
204 |
-
},
|
205 |
-
),
|
206 |
-
],
|
207 |
-
style={
|
208 |
-
"width": "48%",
|
209 |
-
"verticalAlign": "top",
|
210 |
-
},
|
211 |
-
),
|
212 |
-
html.Div(
|
213 |
-
[
|
214 |
-
html.Label(
|
215 |
-
"Select a row to display the material's structure and properties",
|
216 |
-
style={"margin-bottom": "20px"},
|
217 |
-
),
|
218 |
-
# dcc.Dropdown(
|
219 |
-
# id="material-dropdown",
|
220 |
-
# options=[], # Empty options initially
|
221 |
-
# value=None,
|
222 |
-
# ),
|
223 |
-
dash.dash_table.DataTable(
|
224 |
-
id="table",
|
225 |
-
columns=[
|
226 |
-
(
|
227 |
-
{"name": display_names[col], "id": col}
|
228 |
-
if col != "energy"
|
229 |
-
else {
|
230 |
-
"name": display_names[col],
|
231 |
-
"id": col,
|
232 |
-
"type": "numeric",
|
233 |
-
"format": {"specifier": ".2f"},
|
234 |
-
}
|
235 |
-
)
|
236 |
-
for col in display_columns
|
237 |
-
],
|
238 |
-
data=[{}],
|
239 |
-
style_cell={
|
240 |
-
"fontFamily": "Arial",
|
241 |
-
"padding": "10px",
|
242 |
-
"border": "1px solid #ddd", # Subtle border for elegance
|
243 |
-
"textAlign": "left",
|
244 |
-
"fontSize": "14px",
|
245 |
-
},
|
246 |
-
style_header={
|
247 |
-
"backgroundColor": "#f5f5f5", # Light grey header
|
248 |
-
"fontWeight": "bold",
|
249 |
-
"textAlign": "left",
|
250 |
-
"borderBottom": "2px solid #ddd",
|
251 |
-
},
|
252 |
-
style_data={
|
253 |
-
"backgroundColor": "#ffffff",
|
254 |
-
"color": "#333333",
|
255 |
-
"borderBottom": "1px solid #ddd",
|
256 |
-
},
|
257 |
-
style_data_conditional=[
|
258 |
-
{
|
259 |
-
"if": {"state": "active"},
|
260 |
-
"backgroundColor": "#e6f7ff",
|
261 |
-
"border": "1px solid #1890ff",
|
262 |
-
},
|
263 |
-
],
|
264 |
-
style_table={
|
265 |
-
"maxHeight": "400px",
|
266 |
-
"overflowX": "auto",
|
267 |
-
"overflowY": "auto",
|
268 |
-
},
|
269 |
-
style_as_list_view=True,
|
270 |
-
row_selectable="single",
|
271 |
-
selected_rows=[],
|
272 |
-
),
|
273 |
-
],
|
274 |
-
style={
|
275 |
-
"width": "48%",
|
276 |
-
# "maxWidth": "800px",
|
277 |
-
"margin": "0 auto",
|
278 |
-
"padding": "20px",
|
279 |
-
"backgroundColor": "#ffffff",
|
280 |
-
"borderRadius": "10px",
|
281 |
-
"boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
|
282 |
-
},
|
283 |
),
|
284 |
],
|
285 |
-
|
286 |
-
"margin-top": "20px",
|
287 |
-
"margin-bottom": "20px",
|
288 |
-
"display": "flex",
|
289 |
-
"flexDirection": "row",
|
290 |
-
"alignItems": "center",
|
291 |
-
},
|
292 |
),
|
293 |
-
# acknowledgements to mp dash components and crystal toolkit
|
294 |
html.Footer(
|
295 |
[
|
296 |
html.P(
|
@@ -308,16 +117,6 @@ layout = html.Div(
|
|
308 |
style={"textAlign": "center"},
|
309 |
)
|
310 |
],
|
311 |
-
style={
|
312 |
-
"display": "flex",
|
313 |
-
"justifyContent": "center",
|
314 |
-
"alignItems": "center",
|
315 |
-
"flexWrap": "wrap",
|
316 |
-
"padding": "1rem 0",
|
317 |
-
"backgroundColor": "#f1f1f1", # Optional: light gray footer background
|
318 |
-
"borderTop": "1px solid #ddd", # Optional: subtle border at the top
|
319 |
-
"width": "100%",
|
320 |
-
},
|
321 |
),
|
322 |
],
|
323 |
style={
|
@@ -327,34 +126,6 @@ layout = html.Div(
|
|
327 |
)
|
328 |
|
329 |
|
330 |
-
def search_materials(query):
|
331 |
-
query_vector = np.zeros(n_elements)
|
332 |
-
|
333 |
-
if "," in query:
|
334 |
-
element_list = [el.strip() for el in query.split(",")]
|
335 |
-
for el in element_list:
|
336 |
-
query_vector[map_periodic_table[el]] = 1
|
337 |
-
else:
|
338 |
-
# Formula
|
339 |
-
import re
|
340 |
-
|
341 |
-
matches = re.findall(r"([A-Z][a-z]{0,2})(\d*)", query)
|
342 |
-
for el, numb in matches:
|
343 |
-
numb = int(numb) if numb else 1
|
344 |
-
query_vector[map_periodic_table[el]] = numb
|
345 |
-
|
346 |
-
similarity = np.dot(dataset_index, query_vector) / (np.linalg.norm(query_vector))
|
347 |
-
indices = np.argsort(similarity)[::-1][:top_k]
|
348 |
-
|
349 |
-
options = [dataset[int(i)] for i in indices]
|
350 |
-
|
351 |
-
mapping_table_idx_dataset_idx.clear()
|
352 |
-
for i, idx in enumerate(indices):
|
353 |
-
mapping_table_idx_dataset_idx[int(i)] = int(idx)
|
354 |
-
|
355 |
-
return options
|
356 |
-
|
357 |
-
|
358 |
# Callback to update the table based on search
|
359 |
@app.callback(
|
360 |
Output("table", "data"),
|
@@ -365,9 +136,11 @@ def on_submit_materials_input(n_clicks, query):
|
|
365 |
if n_clicks is None or not query:
|
366 |
return []
|
367 |
|
368 |
-
entries = search_materials(
|
|
|
|
|
369 |
|
370 |
-
return [{col: entry[col] for col in
|
371 |
|
372 |
|
373 |
# Callback to display the selected material
|
@@ -376,7 +149,6 @@ def on_submit_materials_input(n_clicks, query):
|
|
376 |
Output("structure-container", "children"),
|
377 |
Output("properties-container", "children"),
|
378 |
],
|
379 |
-
# Input("display-button", "n_clicks"),
|
380 |
Input("table", "active_cell"),
|
381 |
Input("table", "derived_virtual_selected_rows"),
|
382 |
)
|
@@ -408,69 +180,17 @@ def display_material(active_cell, selected_rows):
|
|
408 |
if row["magnetic_moments"]:
|
409 |
structure.add_site_property("magmom", row["magnetic_moments"])
|
410 |
|
411 |
-
sga =
|
412 |
-
|
413 |
-
# Create the StructureMoleculeComponent
|
414 |
-
structure_component = ctc.StructureMoleculeComponent(structure)
|
415 |
|
416 |
# Extract key properties
|
417 |
-
|
418 |
-
|
419 |
-
"Formula": row["chemical_formula_descriptive"],
|
420 |
-
"Energy per atom (eV/atom)": round(
|
421 |
-
row["energy"] / len(row["species_at_sites"]), 3
|
422 |
-
),
|
423 |
-
# "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
|
424 |
-
"Total Magnetization (μB)": round(row["total_magnetization"], 3) if row['total_magnetization'] is not None else None,
|
425 |
-
"Density (g/cm^3)": round(structure.density, 3),
|
426 |
-
"Fermi energy level (eV)": round(row["dos_ef"],3) if row['dos_ef'] is not None else None,
|
427 |
-
"Crystal system": sga.get_crystal_system(),
|
428 |
-
"International Spacegroup": sga.get_symmetry_dataset().international,
|
429 |
-
"Magnetic moments (μB)": np.round(row["magnetic_moments"], 3),
|
430 |
-
"Stress tensor (kB)": np.round(row["stress_tensor"], 3),
|
431 |
-
"Forces on atoms (eV/A)": np.round(row["forces"], 3),
|
432 |
-
# "Bader charges (e-)": np.round(row["charges"], 3), # future release
|
433 |
-
"DFT Functional": row["functional"],
|
434 |
-
"Entalpic fingerprint": row['entalpic_fingerprint'],
|
435 |
-
}
|
436 |
-
|
437 |
-
# Format properties as an HTML table
|
438 |
-
properties_html = html.Table(
|
439 |
-
[
|
440 |
-
html.Tbody(
|
441 |
-
[
|
442 |
-
html.Tr(
|
443 |
-
[
|
444 |
-
html.Th(
|
445 |
-
key,
|
446 |
-
style={
|
447 |
-
"padding": "10px",
|
448 |
-
"verticalAlign": "middle",
|
449 |
-
},
|
450 |
-
),
|
451 |
-
html.Td(
|
452 |
-
str(value),
|
453 |
-
style={
|
454 |
-
"padding": "10px",
|
455 |
-
"borderBottom": "1px solid #ddd",
|
456 |
-
},
|
457 |
-
),
|
458 |
-
],
|
459 |
-
)
|
460 |
-
for key, value in properties.items()
|
461 |
-
],
|
462 |
-
)
|
463 |
-
],
|
464 |
-
style={
|
465 |
-
"width": "100%",
|
466 |
-
"borderCollapse": "collapse",
|
467 |
-
"fontFamily": "'Arial', sans-serif",
|
468 |
-
"fontSize": "14px",
|
469 |
-
"color": "#333333",
|
470 |
-
},
|
471 |
)
|
472 |
|
473 |
-
return
|
|
|
|
|
|
|
474 |
|
475 |
|
476 |
@app.callback(
|
@@ -505,7 +225,7 @@ def update_materials_input_layout(breakpoint_name, width):
|
|
505 |
|
506 |
|
507 |
# Register crystal toolkit with the app
|
508 |
-
ctc.register_crystal_toolkit(app, layout)
|
509 |
|
510 |
if __name__ == "__main__":
|
511 |
app.run_server(debug=True, port=7860, host="0.0.0.0")
|
|
|
|
|
|
|
|
|
1 |
import crystal_toolkit.components as ctc
|
2 |
import dash
|
3 |
import dash_mp_components as dmp
|
4 |
import numpy as np
|
|
|
5 |
import periodictable
|
6 |
from crystal_toolkit.settings import SETTINGS
|
7 |
from dash import dcc, html
|
8 |
from dash.dependencies import Input, Output, State
|
9 |
from dash_breakpoints import WindowBreakpoints
|
|
|
10 |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
|
11 |
from pymatgen.core import Structure
|
12 |
|
13 |
+
from components import (
|
14 |
+
get_display_table,
|
15 |
+
get_dropdown,
|
16 |
+
get_materials_display,
|
17 |
+
get_periodic_table,
|
18 |
+
get_upload_div,
|
19 |
+
)
|
20 |
+
from data_utils import (
|
21 |
+
build_embeddings_index,
|
22 |
+
build_formula_index,
|
23 |
+
get_crystal_plot,
|
24 |
+
get_dataset,
|
25 |
+
get_properties_table,
|
26 |
+
search_materials,
|
27 |
+
)
|
28 |
|
29 |
+
EMPTY_DATA = False
|
30 |
+
CACHE_PATH = None
|
31 |
|
32 |
+
dataset = get_dataset()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
display_columns_query = [
|
35 |
"chemical_formula_descriptive",
|
36 |
"functional",
|
37 |
"immutable_id",
|
38 |
"energy",
|
39 |
]
|
40 |
+
display_names_query = {
|
41 |
"chemical_formula_descriptive": "Formula",
|
42 |
"functional": "Functional",
|
43 |
"immutable_id": "Material ID",
|
44 |
"energy": "Energy (eV)",
|
45 |
}
|
46 |
|
47 |
+
|
48 |
mapping_table_idx_dataset_idx = {}
|
49 |
+
available_similar_materials = []
|
50 |
|
51 |
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
|
|
|
52 |
|
53 |
+
# dataset_index, immutable_id_to_idx = build_formula_index(dataset, cache_path=None)
|
54 |
+
dataset_index, immutable_id_to_idx = build_formula_index(
|
55 |
+
dataset, cache_path=CACHE_PATH, empty_data=EMPTY_DATA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
)
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
# Initialize the Dash app
|
59 |
+
external_stylesheets = [
|
60 |
+
"/assets/styles.css",
|
61 |
+
]
|
62 |
+
|
63 |
+
app = dash.Dash(
|
64 |
+
__name__,
|
65 |
+
external_stylesheets=external_stylesheets,
|
66 |
+
)
|
67 |
server = app.server # Expose the server for deployment
|
68 |
|
69 |
# Define the app layout
|
70 |
+
app.layout = html.Div(
|
71 |
[
|
72 |
WindowBreakpoints(
|
73 |
id="breakpoints",
|
|
|
80 |
),
|
81 |
html.Div(
|
82 |
[
|
83 |
+
get_materials_display(
|
84 |
+
"",
|
85 |
+
"Structure will be displayed here",
|
86 |
+
"Properties will be displayed here",
|
87 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
],
|
89 |
+
className="container-row",
|
|
|
|
|
|
|
|
|
|
|
90 |
),
|
91 |
html.Div(
|
92 |
[
|
93 |
+
get_periodic_table("materials-input", {}),
|
94 |
+
get_display_table(
|
95 |
+
"table",
|
96 |
+
display_names_query,
|
97 |
+
display_columns_query,
|
98 |
+
"Select a row to display the material's structure and properties",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
),
|
100 |
],
|
101 |
+
className="container-row-periodic",
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
),
|
|
|
103 |
html.Footer(
|
104 |
[
|
105 |
html.P(
|
|
|
117 |
style={"textAlign": "center"},
|
118 |
)
|
119 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
),
|
121 |
],
|
122 |
style={
|
|
|
126 |
)
|
127 |
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
# Callback to update the table based on search
|
130 |
@app.callback(
|
131 |
Output("table", "data"),
|
|
|
136 |
if n_clicks is None or not query:
|
137 |
return []
|
138 |
|
139 |
+
entries = search_materials(
|
140 |
+
query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
|
141 |
+
)
|
142 |
|
143 |
+
return [{col: entry[col] for col in display_columns_query} for entry in entries]
|
144 |
|
145 |
|
146 |
# Callback to display the selected material
|
|
|
149 |
Output("structure-container", "children"),
|
150 |
Output("properties-container", "children"),
|
151 |
],
|
|
|
152 |
Input("table", "active_cell"),
|
153 |
Input("table", "derived_virtual_selected_rows"),
|
154 |
)
|
|
|
180 |
if row["magnetic_moments"]:
|
181 |
structure.add_site_property("magmom", row["magnetic_moments"])
|
182 |
|
183 |
+
structure_layout, sga = get_crystal_plot(structure)
|
|
|
|
|
|
|
184 |
|
185 |
# Extract key properties
|
186 |
+
properties_html = get_properties_table(
|
187 |
+
row, structure, sga, [None, None], container_type="results"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
)
|
189 |
|
190 |
+
return (
|
191 |
+
structure_layout,
|
192 |
+
properties_html,
|
193 |
+
)
|
194 |
|
195 |
|
196 |
@app.callback(
|
|
|
225 |
|
226 |
|
227 |
# Register crystal toolkit with the app
|
228 |
+
ctc.register_crystal_toolkit(app, app.layout)
|
229 |
|
230 |
if __name__ == "__main__":
|
231 |
app.run_server(debug=True, port=7860, host="0.0.0.0")
|
assets/styles.css
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Page heading. */
h1 {
    font-size: 24px;
    font-weight: 700;
    color: #333;
}

/* NOTE(review): this is a class selector (elements with class="body"),
   not the <body> element, which is styled further down. Confirm intent. */
.body {
    background-color: #4a4a4a;
}

.header-container {
    display: flex;
    flex-direction: row;
    justify-content: space-between;
    align-items: center;
    padding: 20px;
    margin-bottom: 20px;
}

/* Shared card look used by the structure, properties, search and table panels. */
.container {
    box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
    border-radius: 10px;
    background-color: rgb(249, 249, 249);
    padding: 20px;
    margin-left: 10px;
    margin-right: 10px;
    max-height: 600px;
    justify-content: center;
}

.container-visu {
    width: 45%;
    align-items: center;
}

.container-table {
    width: 50%;
    align-items: center;
    overflow: auto;
}

/* Remove the card background and shadow around the periodic table. */
.periodic-table {
    background-color: transparent;
    box-shadow: none;
}

.container-row {
    width: 100%;
    display: flex;
    flex-direction: row;
    justify-content: space-between;
    padding: 10px;
    margin-bottom: 10px;
    margin-top: 10px;
}

.container-row-periodic {
    width: 100%;
    display: flex;
    flex-direction: row;
    align-items: center;
}

.container-col {
    width: 100%;
    display: flex;
    flex-direction: column;
    justify-content: space-between;
    padding: 10px;
    margin-bottom: 10px;
    margin-top: 10px;
}

body {
    font-family: "Arial", sans-serif;
    font-size: 16px;
}

/* Narrow screens: stack the rows vertically and let panels use the full width. */
@media (max-width: 800px) {
    .container {
        width: 100%;
        margin: 5px;
        margin-top: 10px;
        margin-bottom: 10px;
    }

    .container-row {
        flex-direction: column;
    }

    .container-row-periodic {
        flex-direction: column;
    }

    .container-visu {
        width: 100%;
    }

    .container-table {
        width: 100%;
    }
}

/* Medium screens. Starts at 801px so that it no longer overlaps the
   (max-width: 800px) query at exactly 800px, where the stacked layout
   previously received these 60% / 39% widths. */
@media (max-width: 1000px) and (min-width: 801px) {
    .container-visu {
        width: 60%;
    }

    .container-table {
        width: 39%;
    }
}

footer {
    display: flex;
    justify-content: center;
    align-items: center;
    flex-wrap: wrap;
    margin-top: 40px;
    background-color: #ffffff;
    /* The value must not be quoted: a quoted string is an invalid
       border-top value and the whole declaration is ignored. */
    border-top: 1px solid #ddd;
    width: 100%;
}
|
components.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dash
|
2 |
+
import dash_mp_components as dmp
|
3 |
+
from dash import dcc, html
|
4 |
+
|
5 |
+
# Dataset columns shown in a results table, in left-to-right display order.
display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]

# Header text shown for each of the dataset columns above.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
|
18 |
+
|
19 |
+
|
20 |
+
def get_periodic_table(id, table_kwargs, **style_kwargs):
    """Build the material search panel: a heading plus a ``MaterialsInput``.

    Args:
        id: Component id given to the ``dmp.MaterialsInput`` widget.
        table_kwargs: Mapping of extra keyword arguments forwarded to
            ``dmp.MaterialsInput``. It must not contain ``"id"`` or any of
            the keywords set explicitly below, otherwise the call raises
            ``TypeError`` for a duplicated keyword argument.
        **style_kwargs: Accepted for signature parity with the other
            builders in this module; currently not applied to any component.

    Returns:
        An ``html.Div`` with class ``"container periodic-table"`` wrapping
        the heading and the input (the input sits in a div with id
        ``"materials-input-container"``).
    """
    materials_input = dmp.MaterialsInput(
        allowedInputTypes=[
            "elements",
            "formula",
        ],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        **table_kwargs,
        id=id,
    )

    return html.Div(
        [
            html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
            html.Div(
                [materials_input],
                id="materials-input-container",
                style={
                    "width": "100%",
                },
            ),
        ],
        className="container periodic-table",
    )
|
50 |
+
|
51 |
+
|
52 |
+
def get_dropdown(id, options, **style_kwargs):
    """Build a non-clearable dropdown with the placeholder ``"Embedder"``.

    Args:
        id: Component id of the ``dcc.Dropdown``.
        options: Options passed straight through to ``dcc.Dropdown``.
        **style_kwargs: Collected into a dict and used as the dropdown's
            inline ``style``.

    Returns:
        A ``dcc.Dropdown`` with no initial selection (``value=None``).
    """
    return dcc.Dropdown(
        id=id,
        options=options,
        placeholder="Embedder",
        value=None,
        clearable=False,
        style=style_kwargs,
    )
|
61 |
+
|
62 |
+
|
63 |
+
def get_upload_div(id, **style_kwargs):
    """Build the CIF upload panel: a heading plus a drag-and-drop area.

    Args:
        id: Component id given to the ``dcc.Upload`` widget.
        **style_kwargs: Accepted for signature parity with the other
            builders in this module; currently not applied to any component.

    Returns:
        An ``html.Div`` with class ``"container"`` holding the heading and
        a single-file (``multiple=False``) ``dcc.Upload``.
    """
    upload_area = dcc.Upload(
        id=id,
        children=html.Div(
            [
                "Drag and Drop or ",
                html.A("Select a CIF file"),
            ]
        ),
        style={
            "width": "100%",
            "height": "60px",
            "lineHeight": "60px",
            "borderWidth": "1px",
            "borderStyle": "dashed",
            "borderRadius": "5px",
            "textAlign": "center",
            "margin": "10px",
        },
        multiple=False,
    )

    return html.Div(
        [
            html.H3("Upload a CIF file"),
            upload_area,
        ],
        className="container",
    )
|
90 |
+
|
91 |
+
|
92 |
+
def get_display_table(id, display_names, display_columns, text, **style_kwargs):
    """Build a labelled, single-row-selectable results table.

    Args:
        id: Component id given to the ``DataTable``.
        display_names: Mapping from column id to the header text to show.
            Every entry of ``display_columns`` must be a key, otherwise a
            ``KeyError`` is raised.
        display_columns: Column ids to show, in display order. The column
            with id ``"energy"`` is declared numeric and formatted with two
            decimals (``".2f"``).
        text: Label text rendered above the table.
        **style_kwargs: Accepted for signature parity with the other
            builders in this module; currently not applied to any component.

    Returns:
        An ``html.Div`` with class ``"container"`` holding the label and the
        table.
    """
    columns = []
    for col in display_columns:
        column = {
            "name": display_names[col],
            "id": col,
        }
        if col == "energy":
            column["type"] = "numeric"
            column["format"] = {"specifier": ".2f"}
        columns.append(column)

    table = dash.dash_table.DataTable(
        id=id,
        columns=columns,
        # Placeholder content: one empty row until a callback supplies data.
        data=[{}],
        style_cell={
            "fontFamily": "Arial",
            "padding": "10px",
            "border": "1px solid #ddd",  # Subtle border for elegance
            "textAlign": "left",
            "fontSize": "14px",
        },
        style_header={
            "backgroundColor": "#f5f5f5",  # Light grey header
            "fontWeight": "bold",
            "textAlign": "left",
            "borderBottom": "2px solid #ddd",
        },
        style_data={
            "backgroundColor": "#ffffff",
            "color": "#333333",
            "borderBottom": "1px solid #ddd",
        },
        style_data_conditional=[
            {
                # Highlight the cell that currently has focus.
                "if": {"state": "active"},
                "backgroundColor": "#e6f7ff",
                "border": "1px solid #1890ff",
            },
        ],
        style_table={
            "maxHeight": "400px",
            "overflowX": "auto",
            "overflowY": "auto",
        },
        style_as_list_view=True,
        row_selectable="single",
        selected_rows=[],
    )

    return html.Div(
        [
            html.Label(
                text,
                # camelCase key, like every other style dict in this module.
                style={"marginBottom": "20px"},
            ),
            table,
        ],
        className="container",
    )
|
156 |
+
|
157 |
+
|
158 |
+
def get_materials_display(id, text_materials_div, text_table_div, **style_kwargs):
    """Build the side-by-side structure / properties placeholder row.

    Args:
        id: Suffix appended to the two container ids, giving
            ``"structure-container{id}"`` and ``"properties-container{id}"``.
            An empty string yields the unsuffixed ids.
        text_materials_div: Placeholder content for the structure container.
        text_table_div: Placeholder content for the properties container.
        **style_kwargs: Accepted for signature parity with the other
            builders in this module; currently not applied to any component.

    Returns:
        An ``html.Div`` with class ``"container-row"`` holding both
        containers, each showing its centred placeholder.
    """
    structure_container = html.Div(
        [
            html.Div(
                text_materials_div,
                style={"textAlign": "center"},
            ),
        ],
        id=f"structure-container{id}",
        className="container container-visu",
    )

    properties_container = html.Div(
        id=f"properties-container{id}",
        className="container container-table",
        # Inline style, so it takes precedence over any width set for the
        # "container-table" class in the stylesheet.
        style={"width": "100%"},
        children=[
            html.Div(
                text_table_div,
                style={"textAlign": "center"},
            ),
        ],
    )

    return html.Div(
        [
            structure_container,
            properties_container,
        ],
        className="container-row",
    )
|
data_utils.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
import crystal_toolkit.components as ctc
|
5 |
+
import numpy as np
|
6 |
+
import periodictable
|
7 |
+
from dash import dcc, html
|
8 |
+
from datasets import concatenate_datasets, load_dataset
|
9 |
+
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
|
10 |
+
|
11 |
+
# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# NOTE(review): presumably the maximum number of search results to return;
# the code that reads it is outside this view - confirm.
top_k = 500
|
13 |
+
|
14 |
+
|
15 |
+
def get_dataset():
    """Load the train split of every LeMat-Bulk subset and concatenate them.

    Each subset of ``"LeMaterial/leMat-Bulk"`` is loaded with ``HF_TOKEN``
    and restricted to the columns listed below.

    Returns:
        The result of ``concatenate_datasets`` over the ``"train"`` split of
        each subset, in the order the subsets are listed.
    """
    subsets = [
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    ]

    # "functional" used to be listed twice; each column is requested once.
    # Not requested yet: "energy_corrected" (not yet available in LeMat-Bulk),
    # "band_gap_direct", "band_gap_indirect", "charges" (future release).
    columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "dos_ef",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]

    # Load only the train split of each subset.
    datasets = [
        load_dataset(
            "LeMaterial/leMat-Bulk",
            subset,
            token=HF_TOKEN,
            columns=columns,
        )["train"]
        for subset in subsets
    ]

    return concatenate_datasets(datasets)
|
56 |
+
|
57 |
+
|
58 |
+
# Dataset columns shown in a results table, in left-to-right display order.
display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]

# Header text shown for each of the dataset columns above.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Global shared variables
# NOTE(review): presumably maps a results-table row index to a dataset row
# index; it starts empty and nothing in this view fills it - confirm.
mapping_table_idx_dataset_idx = {}
|
73 |
+
|
74 |
+
|
75 |
+
def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build the element-composition index used for formula search.

    Args:
        dataset: Object exposing ``select(...)`` and
            ``select_columns(...).to_pandas()``.
        index_range: Optional indices handed to ``dataset.select`` to index
            only part of the dataset.
        cache_path: Optional directory holding ``train_df.pkl`` and
            ``dataset_index.pkl``. When given, both are unpickled instead of
            being recomputed from ``dataset``.
        empty_data: When true, skip all work and return a placeholder.

    Returns:
        A ``(dataset_index, immutable_id_to_idx)`` tuple. When computed,
        ``dataset_index`` has one row per material and one column per symbol
        in ``periodictable.elements``; each row is the element counts parsed
        from ``chemical_formula_descriptive``, divided by their sum and then
        scaled to unit L2 norm. ``immutable_id_to_idx`` maps each
        ``immutable_id`` to its row label in the dataframe.
    """
    if empty_data:
        return np.zeros((1, 1)), {}

    use_dataset = dataset
    if index_range is not None:
        use_dataset = dataset.select(index_range)

    # Preprocessing step to create an index for the dataset
    if cache_path is not None:
        # NOTE: pickle can execute arbitrary code on load; cache_path must
        # only ever point at trusted, locally produced files.
        with open(f"{cache_path}/train_df.pkl", "rb") as train_df_file:
            train_df = pickle.load(train_df_file)

        with open(f"{cache_path}/dataset_index.pkl", "rb") as index_file:
            dataset_index = pickle.load(index_file)
    else:
        train_df = use_dataset.select_columns(
            ["chemical_formula_descriptive", "immutable_id"]
        ).to_pandas()

        pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
        extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
        extracted["count"] = extracted["count"].replace("", "1").astype(int)

        wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
            index="level_0",  # original row index
            columns="element",
            values="count",
            aggfunc="sum",
            fill_value=0,
        )

        all_elements = [el.symbol for el in periodictable.elements]  # full element list
        # Reindex the rows as well as the columns: extractall() drops any
        # formula without a match, which would otherwise shift every later
        # row out of alignment with its position in the dataset.
        wide_df = wide_df.reindex(
            index=train_df.index, columns=all_elements, fill_value=0
        )

        dataset_index = wide_df.values

        # All-zero rows keep a divisor of 1 so they stay zero instead of
        # turning into NaN.
        row_sums = np.sum(dataset_index, axis=1)[:, None]
        row_sums[row_sums == 0] = 1
        dataset_index = dataset_index / row_sums

        row_norms = np.linalg.norm(dataset_index, axis=1)[:, None]
        row_norms[row_norms == 0] = 1
        dataset_index = dataset_index / row_norms  # Normalize vectors

    immutable_id_to_idx = {
        immutable_id: idx for idx, immutable_id in train_df["immutable_id"].items()
    }

    return dataset_index, immutable_id_to_idx
|
121 |
+
|
122 |
+
|
123 |
+
import pickle
|
124 |
+
from pathlib import Path
|
125 |
+
|
126 |
+
|
127 |
+
# TODO: Just load the index from a file
|
128 |
+
def build_embeddings_index(empty_data=False):
    """Build the FAISS index over the pickled per-material feature vectors.

    Args:
        empty_data: When true, skip loading and return empty placeholders.

    Returns:
        A ``(index, features_dict, idx_to_immutable_id)`` tuple: the
        ``FAISSIndex`` the vectors were added to, the mapping unpickled from
        ``features_dict.pkl``, and a dict from insertion position to the
        corresponding key of ``features_dict``. With ``empty_data`` the
        tuple is ``(None, {}, {})``.
    """
    if empty_data:
        return None, {}, {}

    # NOTE: pickle can execute arbitrary code on load; features_dict.pkl
    # must be a trusted, locally produced file.
    with open("features_dict.pkl", "rb") as features_file:
        features_dict = pickle.load(features_file)

    from indexer import FAISSIndex

    index = FAISSIndex()
    # Vectors are added in dict iteration order, so the i-th key below is
    # the i-th vector added to the index.
    for key in features_dict:
        index.index.add(features_dict[key].reshape(1, -1))

    idx_to_immutable_id = dict(enumerate(features_dict))

    # index = FAISSIndex.from_store("index.faiss")

    return index, features_dict, idx_to_immutable_id
|
145 |
+
|
146 |
+
|
147 |
+
def search_materials(
    query,
    dataset,
    dataset_index,
    mapping_table_idx_dataset_idx,
    map_periodic_table,
    max_results=None,
):
    """Rank dataset entries by composition similarity to a text query.

    The query is either a comma-separated element list (``"Fe, O"``), where
    every listed element gets weight 1, or a chemical formula
    (``"CH3CH2OH"``), where each element is weighted by its total count.

    Args:
        query: Element list or formula typed by the user.
        dataset: Indexable collection; ``dataset[int(i)]`` yields one entry.
        dataset_index: Matrix with one row per dataset entry and one column
            per key of ``map_periodic_table``.
        mapping_table_idx_dataset_idx: Dict that is cleared and refilled in
            place with ``result position -> dataset index``.
        map_periodic_table: Maps an element symbol to its column in
            ``dataset_index``.
        max_results: Maximum number of entries to return; defaults to the
            module-level ``top_k``.

    Returns:
        The matching dataset entries, most similar first. An empty list when
        the query contains no element (the mapping is cleared in that case).

    Raises:
        KeyError: If the query names a symbol missing from
            ``map_periodic_table``.
    """
    import re

    query_vector = np.zeros(len(map_periodic_table))

    if "," in query:
        for el in (token.strip() for token in query.split(",")):
            if el:  # tolerate trailing or doubled commas
                query_vector[map_periodic_table[el]] = 1
    else:
        # Formula
        for el, numb in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            # Accumulate so repeated elements (the three H groups of
            # "CH3CH2OH") add up, matching how the index sums counts.
            query_vector[map_periodic_table[el]] += int(numb) if numb else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing recognisable in the query: dividing by a zero norm would
        # give an all-NaN, meaningless ranking.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    limit = top_k if max_results is None else max_results
    indices = np.argsort(similarity)[::-1][:limit]

    options = [dataset[int(i)] for i in indices]

    mapping_table_idx_dataset_idx.clear()
    for position, idx in enumerate(indices):
        mapping_table_idx_dataset_idx[int(position)] = int(idx)

    return options
|
176 |
+
|
177 |
+
|
178 |
+
def _round_or_none(value, rounder=round, ndigits=3):
    """Return ``rounder(value, ndigits)``, or ``None`` when *value* is None."""
    return None if value is None else rounder(value, ndigits)


def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render the properties of one material as a two-column HTML table.

    Args:
        row: Mapping holding the dataset columns read below
            (``immutable_id``, ``energy``, ``forces``, ...).
        structure: Object whose ``density`` attribute is displayed.
        sga: Object providing ``get_crystal_system()`` and
            ``get_symmetry_dataset()``.
        properties_container_update: Mutable sequence; slot 0 receives the
            properties when ``container_type`` is ``"query"``, slot 1
            otherwise.
        container_type: ``"query"`` selects slot 0, any other value slot 1.

    Returns:
        An ``html.Table`` with one row per property: the name in a header
        cell and ``str(value)`` in a data cell.
    """
    energy = row["energy"]
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        # Every numeric column may be null for a given entry, so all of them
        # go through the same None guard instead of only a couple.
        "Energy per atom (eV/atom)": (
            round(energy / len(row["species_at_sites"]), 3)
            if energy is not None
            else None
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
        "Total Magnetization (μB)": _round_or_none(row["total_magnetization"]),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": _round_or_none(row["dos_ef"]),
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        "Magnetic moments (μB)": _round_or_none(row["magnetic_moments"], np.round),
        "Stress tensor (kB)": _round_or_none(row["stress_tensor"], np.round),
        "Forces on atoms (eV/A)": _round_or_none(row["forces"], np.round),
        # "Bader charges (e-)": np.round(row["charges"], 3), # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    header_style = {
        "padding": "10px",
        "verticalAlign": "middle",
    }
    value_style = {
        "padding": "10px",
        "borderBottom": "1px solid #ddd",
    }

    # Record the properties in the caller's container.
    slot = 0 if container_type == "query" else 1
    properties_container_update[slot] = properties

    # Format properties as an HTML table
    table_rows = [
        html.Tr(
            [
                html.Th(key, style=header_style),
                html.Td(str(value), style=value_style),
            ],
        )
        for key, value in properties.items()
    ]

    return html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )
|
255 |
+
|
256 |
+
|
257 |
+
def get_crystal_plot(structure):
    """Return the viewer layout and a symmetry analyzer for *structure*.

    Returns:
        A ``(layout, sga)`` tuple: the layout of a
        ``ctc.StructureMoleculeComponent`` built from ``structure`` and the
        ``SpacegroupAnalyzer`` created for the same structure.
    """
    analyzer = SpacegroupAnalyzer(structure)
    # Create the StructureMoleculeComponent
    viewer = ctc.StructureMoleculeComponent(structure)
    layout = viewer.layout()
    return layout, analyzer
|