Ramlaoui committed on
Commit
901176a
1 Parent(s): 198fc98

More responsive

Browse files
Files changed (4) hide show
  1. app.py +59 -339
  2. assets/styles.css +124 -0
  3. components.py +184 -0
  4. data_utils.py +261 -0
app.py CHANGED
@@ -1,112 +1,73 @@
1
- import os
2
- import re
3
-
4
  import crystal_toolkit.components as ctc
5
  import dash
6
  import dash_mp_components as dmp
7
  import numpy as np
8
- import pandas as pd
9
  import periodictable
10
  from crystal_toolkit.settings import SETTINGS
11
  from dash import dcc, html
12
  from dash.dependencies import Input, Output, State
13
  from dash_breakpoints import WindowBreakpoints
14
- from datasets import concatenate_datasets, load_dataset
15
  from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
16
  from pymatgen.core import Structure
17
 
18
- HF_TOKEN = os.environ.get("HF_TOKEN")
19
- top_k = 500
20
-
21
- subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan", "non_compatible"]
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Load only the train split of the dataset
 
24
 
25
- datasets = []
26
- for subset in subsets:
27
- dataset = load_dataset(
28
- "LeMaterial/leMat-Bulk",
29
- subset,
30
- token=HF_TOKEN,
31
- columns=[
32
- "lattice_vectors",
33
- "species_at_sites",
34
- "cartesian_site_positions",
35
- "energy",
36
- # "energy_corrected", # not yet available in LeMat-Bulk
37
- "immutable_id",
38
- "elements",
39
- "functional",
40
- "stress_tensor",
41
- "magnetic_moments",
42
- "forces",
43
- # "band_gap_direct", #future release
44
- # "band_gap_indirect", #future release
45
- "dos_ef",
46
- # "charges", #future release
47
- "functional",
48
- "chemical_formula_reduced",
49
- "chemical_formula_descriptive",
50
- "total_magnetization",
51
- "entalpic_fingerprint"
52
- ],
53
- )
54
- datasets.append(dataset["train"])
55
 
56
- display_columns = [
57
  "chemical_formula_descriptive",
58
  "functional",
59
  "immutable_id",
60
  "energy",
61
  ]
62
- display_names = {
63
  "chemical_formula_descriptive": "Formula",
64
  "functional": "Functional",
65
  "immutable_id": "Material ID",
66
  "energy": "Energy (eV)",
67
  }
68
 
 
69
  mapping_table_idx_dataset_idx = {}
 
70
 
71
  map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
72
- n_elements = len(map_periodic_table)
73
 
74
- # Preprocessing step to create an index for the dataset
75
- # df = pd.concat([x.to_pandas() for x in datasets])
76
- dataset = concatenate_datasets(datasets)
77
- train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
78
-
79
- pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
80
- extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
81
- extracted["count"] = extracted["count"].replace("", "1").astype(int)
82
-
83
- wide_df = extracted.reset_index().pivot_table( # Move index to columns for pivoting
84
- index="level_0", # original row index
85
- columns="element",
86
- values="count",
87
- aggfunc="sum",
88
- fill_value=0,
89
  )
90
 
91
- all_elements = [el.symbol for el in periodictable.elements] # full element list
92
- wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
93
-
94
-
95
- dataset_index = wide_df.values
96
-
97
- dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
98
- dataset_index = (
99
- dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
100
- ) # Normalize vectors
101
-
102
- del train_df, extracted, wide_df
103
-
104
  # Initialize the Dash app
105
- app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
 
 
 
 
 
 
 
106
  server = app.server # Expose the server for deployment
107
 
108
  # Define the app layout
109
- layout = html.Div(
110
  [
111
  WindowBreakpoints(
112
  id="breakpoints",
@@ -119,178 +80,26 @@ layout = html.Div(
119
  ),
120
  html.Div(
121
  [
122
- html.Div(
123
- [
124
- html.Div(
125
- "Search a material to display its structure and properties",
126
- style={"textAlign": "center"},
127
- ),
128
- ],
129
- id="structure-container",
130
- style={
131
- "width": "44%",
132
- "verticalAlign": "top",
133
- "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
134
- "borderRadius": "10px",
135
- "backgroundColor": "#f9f9f9",
136
- "padding": "20px",
137
- "textAlign": "center",
138
- "display": "flex",
139
- "justifyContent": "center",
140
- "alignItems": "center",
141
- },
142
- ),
143
- html.Div(
144
- id="properties-container",
145
- style={
146
- "width": "55%",
147
- "paddingLeft": "4%",
148
- "verticalAlign": "top",
149
- "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
150
- "borderRadius": "10px",
151
- "backgroundColor": "#f9f9f9",
152
- "padding": "20px",
153
- "overflow": "auto",
154
- "maxHeight": "600px",
155
- "display": "flex",
156
- "justifyContent": "center",
157
- "wordWrap": "break-word",
158
- },
159
- children=[
160
- html.Div(
161
- "Properties will be displayed here",
162
- style={"textAlign": "center"},
163
- ),
164
- ],
165
- ),
166
  ],
167
- style={
168
- "marginTop": "20px",
169
- "display": "flex",
170
- "justifyContent": "space-between", # Ensure the two sections are responsive
171
- "flexWrap": "wrap",
172
- },
173
  ),
174
  html.Div(
175
  [
176
- html.Div(
177
- [
178
- html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
179
- html.Div(
180
- [
181
- html.Div(
182
- [
183
- dmp.MaterialsInput(
184
- allowedInputTypes=["elements", "formula"],
185
- hidePeriodicTable=False,
186
- periodicTableMode="toggle",
187
- hideWildcardButton=True,
188
- showSubmitButton=True,
189
- submitButtonText="Search",
190
- type="elements",
191
- id="materials-input",
192
- ),
193
- ],
194
- id="materials-input-container",
195
- style={
196
- "width": "100%",
197
- },
198
- ),
199
- ],
200
- style={
201
- "display": "flex",
202
- "justifyContent": "center",
203
- "width": "100%",
204
- },
205
- ),
206
- ],
207
- style={
208
- "width": "48%",
209
- "verticalAlign": "top",
210
- },
211
- ),
212
- html.Div(
213
- [
214
- html.Label(
215
- "Select a row to display the material's structure and properties",
216
- style={"margin-bottom": "20px"},
217
- ),
218
- # dcc.Dropdown(
219
- # id="material-dropdown",
220
- # options=[], # Empty options initially
221
- # value=None,
222
- # ),
223
- dash.dash_table.DataTable(
224
- id="table",
225
- columns=[
226
- (
227
- {"name": display_names[col], "id": col}
228
- if col != "energy"
229
- else {
230
- "name": display_names[col],
231
- "id": col,
232
- "type": "numeric",
233
- "format": {"specifier": ".2f"},
234
- }
235
- )
236
- for col in display_columns
237
- ],
238
- data=[{}],
239
- style_cell={
240
- "fontFamily": "Arial",
241
- "padding": "10px",
242
- "border": "1px solid #ddd", # Subtle border for elegance
243
- "textAlign": "left",
244
- "fontSize": "14px",
245
- },
246
- style_header={
247
- "backgroundColor": "#f5f5f5", # Light grey header
248
- "fontWeight": "bold",
249
- "textAlign": "left",
250
- "borderBottom": "2px solid #ddd",
251
- },
252
- style_data={
253
- "backgroundColor": "#ffffff",
254
- "color": "#333333",
255
- "borderBottom": "1px solid #ddd",
256
- },
257
- style_data_conditional=[
258
- {
259
- "if": {"state": "active"},
260
- "backgroundColor": "#e6f7ff",
261
- "border": "1px solid #1890ff",
262
- },
263
- ],
264
- style_table={
265
- "maxHeight": "400px",
266
- "overflowX": "auto",
267
- "overflowY": "auto",
268
- },
269
- style_as_list_view=True,
270
- row_selectable="single",
271
- selected_rows=[],
272
- ),
273
- ],
274
- style={
275
- "width": "48%",
276
- # "maxWidth": "800px",
277
- "margin": "0 auto",
278
- "padding": "20px",
279
- "backgroundColor": "#ffffff",
280
- "borderRadius": "10px",
281
- "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
282
- },
283
  ),
284
  ],
285
- style={
286
- "margin-top": "20px",
287
- "margin-bottom": "20px",
288
- "display": "flex",
289
- "flexDirection": "row",
290
- "alignItems": "center",
291
- },
292
  ),
293
- # acknowledgements to mp dash components and crystal toolkit
294
  html.Footer(
295
  [
296
  html.P(
@@ -308,16 +117,6 @@ layout = html.Div(
308
  style={"textAlign": "center"},
309
  )
310
  ],
311
- style={
312
- "display": "flex",
313
- "justifyContent": "center",
314
- "alignItems": "center",
315
- "flexWrap": "wrap",
316
- "padding": "1rem 0",
317
- "backgroundColor": "#f1f1f1", # Optional: light gray footer background
318
- "borderTop": "1px solid #ddd", # Optional: subtle border at the top
319
- "width": "100%",
320
- },
321
  ),
322
  ],
323
  style={
@@ -327,34 +126,6 @@ layout = html.Div(
327
  )
328
 
329
 
330
- def search_materials(query):
331
- query_vector = np.zeros(n_elements)
332
-
333
- if "," in query:
334
- element_list = [el.strip() for el in query.split(",")]
335
- for el in element_list:
336
- query_vector[map_periodic_table[el]] = 1
337
- else:
338
- # Formula
339
- import re
340
-
341
- matches = re.findall(r"([A-Z][a-z]{0,2})(\d*)", query)
342
- for el, numb in matches:
343
- numb = int(numb) if numb else 1
344
- query_vector[map_periodic_table[el]] = numb
345
-
346
- similarity = np.dot(dataset_index, query_vector) / (np.linalg.norm(query_vector))
347
- indices = np.argsort(similarity)[::-1][:top_k]
348
-
349
- options = [dataset[int(i)] for i in indices]
350
-
351
- mapping_table_idx_dataset_idx.clear()
352
- for i, idx in enumerate(indices):
353
- mapping_table_idx_dataset_idx[int(i)] = int(idx)
354
-
355
- return options
356
-
357
-
358
  # Callback to update the table based on search
359
  @app.callback(
360
  Output("table", "data"),
@@ -365,9 +136,11 @@ def on_submit_materials_input(n_clicks, query):
365
  if n_clicks is None or not query:
366
  return []
367
 
368
- entries = search_materials(query)
 
 
369
 
370
- return [{col: entry[col] for col in display_columns} for entry in entries]
371
 
372
 
373
  # Callback to display the selected material
@@ -376,7 +149,6 @@ def on_submit_materials_input(n_clicks, query):
376
  Output("structure-container", "children"),
377
  Output("properties-container", "children"),
378
  ],
379
- # Input("display-button", "n_clicks"),
380
  Input("table", "active_cell"),
381
  Input("table", "derived_virtual_selected_rows"),
382
  )
@@ -408,69 +180,17 @@ def display_material(active_cell, selected_rows):
408
  if row["magnetic_moments"]:
409
  structure.add_site_property("magmom", row["magnetic_moments"])
410
 
411
- sga = SpacegroupAnalyzer(structure)
412
-
413
- # Create the StructureMoleculeComponent
414
- structure_component = ctc.StructureMoleculeComponent(structure)
415
 
416
  # Extract key properties
417
- properties = {
418
- "Material ID": row["immutable_id"],
419
- "Formula": row["chemical_formula_descriptive"],
420
- "Energy per atom (eV/atom)": round(
421
- row["energy"] / len(row["species_at_sites"]), 3
422
- ),
423
- # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
424
- "Total Magnetization (μB)": round(row["total_magnetization"], 3) if row['total_magnetization'] is not None else None,
425
- "Density (g/cm^3)": round(structure.density, 3),
426
- "Fermi energy level (eV)": round(row["dos_ef"],3) if row['dos_ef'] is not None else None,
427
- "Crystal system": sga.get_crystal_system(),
428
- "International Spacegroup": sga.get_symmetry_dataset().international,
429
- "Magnetic moments (μB)": np.round(row["magnetic_moments"], 3),
430
- "Stress tensor (kB)": np.round(row["stress_tensor"], 3),
431
- "Forces on atoms (eV/A)": np.round(row["forces"], 3),
432
- # "Bader charges (e-)": np.round(row["charges"], 3), # future release
433
- "DFT Functional": row["functional"],
434
- "Entalpic fingerprint": row['entalpic_fingerprint'],
435
- }
436
-
437
- # Format properties as an HTML table
438
- properties_html = html.Table(
439
- [
440
- html.Tbody(
441
- [
442
- html.Tr(
443
- [
444
- html.Th(
445
- key,
446
- style={
447
- "padding": "10px",
448
- "verticalAlign": "middle",
449
- },
450
- ),
451
- html.Td(
452
- str(value),
453
- style={
454
- "padding": "10px",
455
- "borderBottom": "1px solid #ddd",
456
- },
457
- ),
458
- ],
459
- )
460
- for key, value in properties.items()
461
- ],
462
- )
463
- ],
464
- style={
465
- "width": "100%",
466
- "borderCollapse": "collapse",
467
- "fontFamily": "'Arial', sans-serif",
468
- "fontSize": "14px",
469
- "color": "#333333",
470
- },
471
  )
472
 
473
- return structure_component.layout(), properties_html
 
 
 
474
 
475
 
476
  @app.callback(
@@ -505,7 +225,7 @@ def update_materials_input_layout(breakpoint_name, width):
505
 
506
 
507
  # Register crystal toolkit with the app
508
- ctc.register_crystal_toolkit(app, layout)
509
 
510
  if __name__ == "__main__":
511
  app.run_server(debug=True, port=7860, host="0.0.0.0")
 
 
 
 
1
  import crystal_toolkit.components as ctc
2
  import dash
3
  import dash_mp_components as dmp
4
  import numpy as np
 
5
  import periodictable
6
  from crystal_toolkit.settings import SETTINGS
7
  from dash import dcc, html
8
  from dash.dependencies import Input, Output, State
9
  from dash_breakpoints import WindowBreakpoints
 
10
  from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
11
  from pymatgen.core import Structure
12
 
13
+ from components import (
14
+ get_display_table,
15
+ get_dropdown,
16
+ get_materials_display,
17
+ get_periodic_table,
18
+ get_upload_div,
19
+ )
20
+ from data_utils import (
21
+ build_embeddings_index,
22
+ build_formula_index,
23
+ get_crystal_plot,
24
+ get_dataset,
25
+ get_properties_table,
26
+ search_materials,
27
+ )
28
 
29
+ EMPTY_DATA = False
30
+ CACHE_PATH = None
31
 
32
+ dataset = get_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ display_columns_query = [
35
  "chemical_formula_descriptive",
36
  "functional",
37
  "immutable_id",
38
  "energy",
39
  ]
40
+ display_names_query = {
41
  "chemical_formula_descriptive": "Formula",
42
  "functional": "Functional",
43
  "immutable_id": "Material ID",
44
  "energy": "Energy (eV)",
45
  }
46
 
47
+
48
  mapping_table_idx_dataset_idx = {}
49
+ available_similar_materials = []
50
 
51
  map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
 
52
 
53
+ # dataset_index, immutable_id_to_idx = build_formula_index(dataset, cache_path=None)
54
+ dataset_index, immutable_id_to_idx = build_formula_index(
55
+ dataset, cache_path=CACHE_PATH, empty_data=EMPTY_DATA
 
 
 
 
 
 
 
 
 
 
 
 
56
  )
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Initialize the Dash app
59
+ external_stylesheets = [
60
+ "/assets/styles.css",
61
+ ]
62
+
63
+ app = dash.Dash(
64
+ __name__,
65
+ external_stylesheets=external_stylesheets,
66
+ )
67
  server = app.server # Expose the server for deployment
68
 
69
  # Define the app layout
70
+ app.layout = html.Div(
71
  [
72
  WindowBreakpoints(
73
  id="breakpoints",
 
80
  ),
81
  html.Div(
82
  [
83
+ get_materials_display(
84
+ "",
85
+ "Structure will be displayed here",
86
+ "Properties will be displayed here",
87
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  ],
89
+ className="container-row",
 
 
 
 
 
90
  ),
91
  html.Div(
92
  [
93
+ get_periodic_table("materials-input", {}),
94
+ get_display_table(
95
+ "table",
96
+ display_names_query,
97
+ display_columns_query,
98
+ "Select a row to display the material's structure and properties",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  ),
100
  ],
101
+ className="container-row-periodic",
 
 
 
 
 
 
102
  ),
 
103
  html.Footer(
104
  [
105
  html.P(
 
117
  style={"textAlign": "center"},
118
  )
119
  ],
 
 
 
 
 
 
 
 
 
 
120
  ),
121
  ],
122
  style={
 
126
  )
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  # Callback to update the table based on search
130
  @app.callback(
131
  Output("table", "data"),
 
136
  if n_clicks is None or not query:
137
  return []
138
 
139
+ entries = search_materials(
140
+ query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
141
+ )
142
 
143
+ return [{col: entry[col] for col in display_columns_query} for entry in entries]
144
 
145
 
146
  # Callback to display the selected material
 
149
  Output("structure-container", "children"),
150
  Output("properties-container", "children"),
151
  ],
 
152
  Input("table", "active_cell"),
153
  Input("table", "derived_virtual_selected_rows"),
154
  )
 
180
  if row["magnetic_moments"]:
181
  structure.add_site_property("magmom", row["magnetic_moments"])
182
 
183
+ structure_layout, sga = get_crystal_plot(structure)
 
 
 
184
 
185
  # Extract key properties
186
+ properties_html = get_properties_table(
187
+ row, structure, sga, [None, None], container_type="results"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  )
189
 
190
+ return (
191
+ structure_layout,
192
+ properties_html,
193
+ )
194
 
195
 
196
  @app.callback(
 
225
 
226
 
227
  # Register crystal toolkit with the app
228
+ ctc.register_crystal_toolkit(app, app.layout)
229
 
230
  if __name__ == "__main__":
231
  app.run_server(debug=True, port=7860, host="0.0.0.0")
assets/styles.css ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ font-size: 24px;
3
+ font-weight: 700;
4
+ color: #333;
5
+ }
6
+
7
+ .body {
8
+ background-color: #4a4a4a;
9
+ }
10
+
11
+ .header-container {
12
+ display: flex;
13
+ flex-direction: row;
14
+ justify-content: space-between;
15
+ align-items: center;
16
+ padding: 20px;
17
+ margin-bottom: 20px;
18
+ }
19
+
20
+ .container {
21
+ box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
22
+ border-radius: 10px;
23
+ background-color: rgb(249, 249, 249);
24
+ padding: 20px;
25
+ margin-left: 10px;
26
+ margin-right: 10px;
27
+ max-height: 600px;
28
+ justify-content: center;
29
+ }
30
+
31
+ .container-visu {
32
+ width: 45%;
33
+ align-items: center;
34
+ }
35
+
36
+ .container-table {
37
+ width: 50%;
38
+ align-items: center;
39
+ overflow: auto;
40
+ }
41
+
42
+ /* remove the background of the periodic table */
43
+ .periodic-table {
44
+ background-color: transparent;
45
+ box-shadow: none;
46
+ }
47
+
48
+ .container-row {
49
+ width: 100%;
50
+ display: flex;
51
+ flex-direction: row;
52
+ justify-content: space-between;
53
+ padding: 10px;
54
+ margin-bottom: 10px;
55
+ margin-top: 10px;
56
+ }
57
+
58
+ .container-row-periodic {
59
+ width: 100%;
60
+ display: flex;
61
+ flex-direction: row;
62
+ align-items: center;
63
+ }
64
+
65
+ .container-col {
66
+ width: 100%;
67
+ display: flex;
68
+ flex-direction: column;
69
+ justify-content: space-between;
70
+ padding: 10px;
71
+ margin-bottom: 10px;
72
+ margin-top: 10px;
73
+ }
74
+
75
+ body {
76
+ font-family: "Arial", sans-serif;
77
+ font-size: 16px;
78
+ }
79
+
80
+ @media (max-width: 800px) {
81
+ .container {
82
+ width: 100%;
83
+ margin: 5px;
84
+ margin-top: 10px;
85
+ margin-bottom: 10px;
86
+ }
87
+
88
+ .container-row {
89
+ flex-direction: column;
90
+ }
91
+
92
+ .container-row-periodic {
93
+ flex-direction: column;
94
+ }
95
+
96
+ .container-visu {
97
+ width: 100%;
98
+ }
99
+
100
+ .container-table {
101
+ width: 100%;
102
+ }
103
+ }
104
+
105
/* Medium screens (801px-1000px). The range starts at 801px so that it never
   matches together with the `max-width: 800px` query, which stacks the
   containers vertically at full width. */
@media (max-width: 1000px) and (min-width: 801px) {
    .container-visu {
        width: 60%;
    }

    .container-table {
        width: 39%;
    }
}
114
+
115
/* Page footer: centred, wrapping flex row separated from the content by a
   thin top border. */
footer {
    display: flex;
    justify-content: center;
    align-items: center;
    flex-wrap: wrap;
    margin-top: 40px;
    background-color: #ffffff;
    /* The value must be unquoted; a quoted string is not a valid border
       shorthand and the whole declaration would be ignored. */
    border-top: 1px solid #ddd;
    width: 100%;
}
components.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dash
2
+ import dash_mp_components as dmp
3
+ from dash import dcc, html
4
+
5
+ display_columns = [
6
+ "chemical_formula_descriptive",
7
+ "functional",
8
+ "immutable_id",
9
+ "energy",
10
+ ]
11
+
12
+ display_names = {
13
+ "chemical_formula_descriptive": "Formula",
14
+ "functional": "Functional",
15
+ "immutable_id": "Material ID",
16
+ "energy": "Energy (eV)",
17
+ }
18
+
19
+
20
def get_periodic_table(id, table_kwargs, **style_kwargs):
    """Build the search panel: a heading plus a ``dmp.MaterialsInput`` widget.

    Args:
        id: Component id given to the ``MaterialsInput``.
        table_kwargs: Extra keyword arguments forwarded to ``MaterialsInput``.
        **style_kwargs: Accepted for signature parity with the other builders;
            not used by this function.

    Returns:
        An ``html.Div`` with class ``"container periodic-table"``.
    """
    heading = html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')")

    materials_input = dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        **table_kwargs,
        id=id,
    )

    input_wrapper = html.Div(
        [materials_input],
        id="materials-input-container",
        style={"width": "100%"},
    )

    return html.Div(
        [heading, input_wrapper],
        className="container periodic-table",
    )
50
+
51
+
52
def get_dropdown(id, options, **style_kwargs):
    """Create a non-clearable ``dcc.Dropdown`` with the "Embedder" placeholder.

    Args:
        id: Component id of the dropdown.
        options: Options passed straight to ``dcc.Dropdown``.
        **style_kwargs: Collected into a dict and used as the dropdown's
            ``style``.

    Returns:
        The configured ``dcc.Dropdown`` with no initial value.
    """
    dropdown_settings = {
        "id": id,
        "options": options,
        "placeholder": "Embedder",
        "value": None,
        "clearable": False,
        "style": style_kwargs,
    }
    return dcc.Dropdown(**dropdown_settings)
61
+
62
+
63
def get_upload_div(id, **style_kwargs):
    """Build the CIF upload panel: a heading above a single-file ``dcc.Upload``.

    Args:
        id: Component id given to the ``dcc.Upload``.
        **style_kwargs: Accepted for signature parity with the other builders;
            not used by this function.

    Returns:
        An ``html.Div`` with class ``"container"``.
    """
    drop_zone_style = {
        "width": "100%",
        "height": "60px",
        "lineHeight": "60px",
        "borderWidth": "1px",
        "borderStyle": "dashed",
        "borderRadius": "5px",
        "textAlign": "center",
        "margin": "10px",
    }
    prompt = html.Div(["Drag and Drop or ", html.A("Select a CIF file")])
    uploader = dcc.Upload(
        id=id,
        children=prompt,
        style=drop_zone_style,
        multiple=False,
    )
    return html.Div(
        [html.H3("Upload a CIF file"), uploader],
        className="container",
    )
90
+
91
+
92
def get_display_table(id, display_names, display_columns, text, **style_kwargs):
    """Build a labelled, single-row-selectable ``DataTable`` panel.

    Args:
        id: Component id of the ``DataTable``.
        display_names: Mapping from column id to the header shown to the user.
        display_columns: Column ids, in display order.
        text: Label rendered above the table.
        **style_kwargs: Accepted for signature parity with the other builders;
            not used by this function.

    Returns:
        An ``html.Div`` with class ``"container"`` holding the label and table.
    """
    column_specs = []
    for column_id in display_columns:
        spec = {"name": display_names[column_id], "id": column_id}
        if column_id == "energy":
            # Energies are numeric and rendered with two decimals.
            spec["type"] = "numeric"
            spec["format"] = {"specifier": ".2f"}
        column_specs.append(spec)

    cell_style = {
        "fontFamily": "Arial",
        "padding": "10px",
        "border": "1px solid #ddd",
        "textAlign": "left",
        "fontSize": "14px",
    }
    header_style = {
        "backgroundColor": "#f5f5f5",
        "fontWeight": "bold",
        "textAlign": "left",
        "borderBottom": "2px solid #ddd",
    }
    data_style = {
        "backgroundColor": "#ffffff",
        "color": "#333333",
        "borderBottom": "1px solid #ddd",
    }
    # Highlight applied to the currently active cell.
    active_cell_style = {
        "if": {"state": "active"},
        "backgroundColor": "#e6f7ff",
        "border": "1px solid #1890ff",
    }
    table_style = {
        "maxHeight": "400px",
        "overflowX": "auto",
        "overflowY": "auto",
    }

    label = html.Label(text, style={"margin-bottom": "20px"})
    table = dash.dash_table.DataTable(
        id=id,
        columns=column_specs,
        data=[{}],
        style_cell=cell_style,
        style_header=header_style,
        style_data=data_style,
        style_data_conditional=[active_cell_style],
        style_table=table_style,
        style_as_list_view=True,
        row_selectable="single",
        selected_rows=[],
    )
    return html.Div([label, table], className="container")
156
+
157
+
158
def get_materials_display(id, text_materials_div, text_table_div, **style_kwargs):
    """Build the side-by-side structure and properties placeholders.

    Args:
        id: Suffix appended to the ``structure-container`` and
            ``properties-container`` component ids.
        text_materials_div: Placeholder text for the structure panel.
        text_table_div: Placeholder text for the properties panel.
        **style_kwargs: Accepted for signature parity with the other builders;
            not used by this function.

    Returns:
        An ``html.Div`` with class ``"container-row"`` holding both panels.
    """
    centered = {"textAlign": "center"}

    structure_panel = html.Div(
        [html.Div(text_materials_div, style=centered)],
        id=f"structure-container{id}",
        className="container container-visu",
    )
    properties_panel = html.Div(
        id=f"properties-container{id}",
        className="container container-table",
        style={"width": "100%"},
        children=[html.Div(text_table_div, style={"textAlign": "center"})],
    )

    return html.Div(
        [structure_panel, properties_panel],
        className="container-row",
    )
data_utils.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import crystal_toolkit.components as ctc
5
+ import numpy as np
6
+ import periodictable
7
+ from dash import dcc, html
8
+ from datasets import concatenate_datasets, load_dataset
9
+ from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
10
+
11
+ HF_TOKEN = os.environ.get("HF_TOKEN")
12
+ top_k = 500
13
+
14
+
15
def get_dataset():
    """Load the train split of every LeMat-Bulk subset and concatenate them.

    Each subset of ``LeMaterial/leMat-Bulk`` is loaded with ``HF_TOKEN`` and
    restricted to the columns the application reads.

    Returns:
        The result of ``concatenate_datasets`` over the ``"train"`` split of
        each subset, in the order the subsets are listed.
    """
    subsets = [
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    ]

    # Every column is requested exactly once ("functional" was listed twice).
    # Not requested yet: "energy_corrected" (not yet available in LeMat-Bulk),
    # "band_gap_direct", "band_gap_indirect" and "charges" (future release).
    columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "dos_ef",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]

    train_splits = [
        load_dataset(
            "LeMaterial/leMat-Bulk",
            subset,
            token=HF_TOKEN,
            columns=columns,
        )["train"]
        for subset in subsets
    ]
    return concatenate_datasets(train_splits)
56
+
57
+
58
+ display_columns = [
59
+ "chemical_formula_descriptive",
60
+ "functional",
61
+ "immutable_id",
62
+ "energy",
63
+ ]
64
+ display_names = {
65
+ "chemical_formula_descriptive": "Formula",
66
+ "functional": "Functional",
67
+ "immutable_id": "Material ID",
68
+ "energy": "Energy (eV)",
69
+ }
70
+
71
+ # Global shared variables
72
+ mapping_table_idx_dataset_idx = {}
73
+
74
+
75
def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build the per-material element-count index used for formula search.

    Args:
        dataset: Dataset exposing ``select`` and ``select_columns(...).to_pandas()``
            with the columns ``chemical_formula_descriptive`` and
            ``immutable_id``.
        index_range: Optional row selection passed to ``dataset.select``.
        cache_path: Optional directory containing ``train_df.pkl`` and
            ``dataset_index.pkl``; when given, both are loaded instead of
            being recomputed.
        empty_data: When true, skip all work and return a dummy index.

    Returns:
        A ``(dataset_index, immutable_id_to_idx)`` tuple. ``dataset_index`` has
        one row per material and one column per ``periodictable`` element,
        each non-empty row scaled to unit L2 norm. ``immutable_id_to_idx``
        maps each ``immutable_id`` to its row label in the data frame.
    """
    if empty_data:
        return np.zeros((1, 1)), {}

    use_dataset = dataset
    if index_range is not None:
        use_dataset = dataset.select(index_range)

    if cache_path is not None:
        # NOTE(security): unpickling executes arbitrary code; only point
        # cache_path at directories you trust.
        with open(f"{cache_path}/train_df.pkl", "rb") as handle:
            train_df = pickle.load(handle)
        with open(f"{cache_path}/dataset_index.pkl", "rb") as handle:
            dataset_index = pickle.load(handle)
    else:
        train_df = use_dataset.select_columns(
            ["chemical_formula_descriptive", "immutable_id"]
        ).to_pandas()

        # One (element, count) match per symbol occurrence in each formula.
        pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
        extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
        extracted["count"] = extracted["count"].replace("", "1").astype(int)

        wide_df = extracted.reset_index().pivot_table(
            index="level_0",  # original row index
            columns="element",
            values="count",
            aggfunc="sum",
            fill_value=0,
        )

        all_elements = [el.symbol for el in periodictable.elements]
        # Reindex rows as well as columns: formulas without any match are
        # absent from the pivot, and dropping them would shift every later
        # row out of alignment with the dataset.
        wide_df = wide_df.reindex(
            index=train_df.index, columns=all_elements, fill_value=0
        )

        counts = wide_df.to_numpy(dtype=float)
        # Scaling by the L2 norm alone equals the former "divide by the sum,
        # then by the norm"; all-zero rows are left at zero instead of NaN.
        norms = np.linalg.norm(counts, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        dataset_index = counts / norms

    immutable_id_to_idx = {
        material_id: row_label
        for row_label, material_id in train_df["immutable_id"].to_dict().items()
    }

    return dataset_index, immutable_id_to_idx
121
+
122
+
123
+ import pickle
124
+ from pathlib import Path
125
+
126
+
127
+ # TODO: Just load the index from a file
128
def build_embeddings_index(empty_data=False):
    """Load the stored feature vectors and add them to a ``FAISSIndex``.

    Args:
        empty_data: When true, skip all work and return empty placeholders.

    Returns:
        A ``(index, features_dict, idx_to_immutable_id)`` tuple, where
        ``features_dict`` is the content of ``features_dict.pkl`` and
        ``idx_to_immutable_id`` maps insertion position to its key.
        With ``empty_data`` the tuple is ``(None, {}, {})``.
    """
    if empty_data:
        return None, {}, {}

    # NOTE(security): unpickling executes arbitrary code; features_dict.pkl
    # must come from a trusted source.
    with open("features_dict.pkl", "rb") as handle:
        features_dict = pickle.load(handle)

    # Imported lazily so the empty-data path does not need the indexer module.
    from indexer import FAISSIndex

    index = FAISSIndex()
    for key in features_dict:
        # Each vector is added as a single-row 2-D array.
        index.index.add(features_dict[key].reshape(1, -1))

    # Positions follow the insertion order used in the loop above.
    idx_to_immutable_id = dict(enumerate(features_dict))

    return index, features_dict, idx_to_immutable_id
145
+
146
+
147
def search_materials(
    query,
    dataset,
    dataset_index,
    mapping_table_idx_dataset_idx,
    map_periodic_table,
    limit=None,
):
    """Rank dataset entries by similarity between their composition and a query.

    Args:
        query: Either comma-separated element symbols (``"Ac,Cd,Ge"``) or a
            formula (``"Ac2CdGe3"``).
        dataset: Indexable collection; ``dataset[i]`` is returned for hit ``i``.
        dataset_index: Array with one row per dataset entry and one column per
            entry of ``map_periodic_table``.
        mapping_table_idx_dataset_idx: Dict that is cleared and refilled in
            place with ``result position -> dataset row``.
        map_periodic_table: Mapping from element symbol to column index.
        limit: Maximum number of results; ``None`` uses the module-level
            ``top_k``.

    Returns:
        The matching dataset entries, most similar first. The list is empty
        when the query contains no symbol known to ``map_periodic_table``.
    """
    if limit is None:
        limit = top_k

    query_vector = np.zeros(len(map_periodic_table))

    if "," in query:
        # Element list: every known symbol gets the same weight.
        for part in query.split(","):
            column = map_periodic_table.get(part.strip())
            if column is not None:
                query_vector[column] = 1
    else:
        # Formula: repeated symbols accumulate, matching how the index sums
        # the counts of an element that appears several times.
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            column = map_periodic_table.get(symbol)
            if column is not None:
                query_vector[column] += int(count) if count else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing recognisable in the query: avoid dividing by zero and
        # report no hits rather than an arbitrary ordering.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:limit]

    options = [dataset[int(idx)] for idx in indices]

    mapping_table_idx_dataset_idx.clear()
    for position, idx in enumerate(indices):
        mapping_table_idx_dataset_idx[int(position)] = int(idx)

    return options
176
+
177
+
178
def _round_array_or_none(values, decimals=3):
    """Round array-like ``values`` with NumPy, passing ``None`` through."""
    if values is None:
        return None
    return np.round(values, decimals)


def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render the key properties of one material as an HTML table.

    Args:
        row: Mapping with the dataset fields read below (``immutable_id``,
            ``energy``, ``species_at_sites``, ``magnetic_moments``, ...).
        structure: Object exposing ``density``.
        sga: Object exposing ``get_crystal_system()`` and
            ``get_symmetry_dataset().international``.
        properties_container_update: Two-slot mutable sequence; the computed
            properties are stored in slot 0 when ``container_type`` is
            ``"query"`` and in slot 1 otherwise.
        container_type: Selects which slot is updated.

    Returns:
        An ``html.Table`` with one row per property; values are rendered
        with ``str``.
    """
    total_magnetization = row["total_magnetization"]
    fermi_level = row["dos_ef"]

    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": round(
            row["energy"] / len(row["species_at_sites"]), 3
        ),
        "Total Magnetization (μB)": (
            round(total_magnetization, 3) if total_magnetization is not None else None
        ),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": (
            round(fermi_level, 3) if fermi_level is not None else None
        ),
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        # These fields may be missing for a material; np.round(None) would
        # raise, so missing values are shown as None like the scalars above.
        "Magnetic moments (μB)": _round_array_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": _round_array_or_none(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _round_array_or_none(row["forces"]),
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    # Expose the raw values to the caller through the in/out container.
    slot = 0 if container_type == "query" else 1
    properties_container_update[slot] = properties

    header_style = {
        "padding": "10px",
        "verticalAlign": "middle",
    }
    value_style = {
        "padding": "10px",
        "borderBottom": "1px solid #ddd",
    }

    table_rows = [
        html.Tr(
            [
                html.Th(key, style=header_style),
                html.Td(str(value), style=value_style),
            ],
        )
        for key, value in properties.items()
    ]

    return html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )
255
+
256
+
257
def get_crystal_plot(structure):
    """Create the structure viewer layout and a symmetry analyzer.

    Args:
        structure: Structure handed to both ``SpacegroupAnalyzer`` and
            ``ctc.StructureMoleculeComponent``.

    Returns:
        A ``(layout, analyzer)`` tuple: the layout of the
        ``StructureMoleculeComponent`` and the ``SpacegroupAnalyzer``.
    """
    analyzer = SpacegroupAnalyzer(structure)
    viewer = ctc.StructureMoleculeComponent(structure)
    layout = viewer.layout()
    return layout, analyzer