SLICES / app.py
xiaohang07's picture
Update app.py
8f98f3c verified
raw
history blame contribute delete
No virus
5.98 kB
import gradio as gr
from slices.core import SLICES
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.io.ase import AseAtomsAdaptor
from ase.io import write as ase_write
import tempfile
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
import os
# Initialize the SLICES backend shared by every conversion callback below.
# relax_model selects CHGNet; fmax and steps are forwarded as given --
# presumably the force-convergence threshold and the step cap used when
# relaxing reconstructed structures (confirm against slices.core).
backend = SLICES(
    relax_model="chgnet",
    fmax=0.4,
    steps=25,
)
def wrap_structure(structure):
    """Wrap every site's fractional coordinates into [0, 1), in place.

    Each site is replaced by an equivalent site whose fractional coordinates
    are reduced modulo 1. Species and site properties are carried over.

    Args:
        structure: A pymatgen ``Structure``; it is modified in place.

    Returns:
        The same ``structure`` object, to allow call chaining.
    """
    # Iterate over a snapshot so the loop never observes its own replacements.
    for index, site in enumerate(list(structure)):
        frac_coords = site.frac_coords % 1.0
        # Float modulo maps tiny negatives (e.g. -1e-17) to exactly 1.0,
        # which lies outside [0, 1); fold those back onto 0.0.
        frac_coords[frac_coords >= 1.0] = 0.0
        structure.replace(
            index,
            species=site.species,
            coords=frac_coords,
            coords_are_cartesian=False,
            # replace() otherwise creates the new site with no properties,
            # silently discarding e.g. magnetic moments.
            properties=dict(site.properties),
        )
    return structure
def get_primitive_structure(structure):
    """Return the standardized primitive cell of ``structure``.

    Symmetry is detected by pymatgen's ``SpacegroupAnalyzer`` with its
    default tolerances (no symprec/angle_tolerance is passed here).
    """
    return SpacegroupAnalyzer(structure).get_primitive_standard_structure()
def visualize_structure(structure):
    """Render ``structure`` to a PNG file and return the file's path.

    The image is drawn by ASE with a fixed ``10x,10y,10z`` viewing rotation.
    The file is deliberately left on disk: the returned path is handed to a
    Gradio ``Image`` output, which still needs to read it afterwards.

    Args:
        structure: A pymatgen ``Structure``.

    Returns:
        Path (str) of the newly written PNG file.

    Raises:
        Whatever the ASE conversion or writer raises; the temporary file is
        removed before the exception propagates.
    """
    atoms = AseAtomsAdaptor.get_atoms(structure)
    # Create the file, then close our handle before ASE reopens it by name:
    # reopening a still-open NamedTemporaryFile fails on Windows.
    fd, image_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        ase_write(image_path, atoms, format="png", rotation="10x,10y,10z")
    except Exception:
        # Don't leave an empty/partial image behind on failure.
        os.remove(image_path)
        raise
    return image_path
def process_structure(structure):
    """Normalize ``structure`` before conversion.

    Sites are first wrapped back into the unit cell (this mutates the
    argument in place), then the standardized primitive cell of the wrapped
    structure is computed and returned.
    """
    wrapped = wrap_structure(structure)
    primitive = get_primitive_structure(wrapped)
    return primitive
def cif_to_slices(cif_file):
    """Convert a CIF file to a SLICES string for the Gradio UI.

    Args:
        cif_file: Value of a Gradio ``File`` component. Depending on the
            Gradio version this is a plain path string or a tempfile-like
            object exposing the path as ``.name``. It is ``None`` when no
            file has been provided.

    Returns:
        A 5-tuple matching the outputs wired to the convert button:
        ``(slices_string, structure_image_path, status_message,
        slices_for_cif_tab, slices_for_augmentation_tab)``.
        On failure the SLICES fields are empty, the image is ``None``, and
        the status message carries the error.
    """
    if cif_file is None:
        return "", None, "Conversion failed. Error: No CIF file provided.", "", ""
    # Older Gradio passes a tempfile wrapper; newer Gradio passes a filepath str.
    cif_path = getattr(cif_file, "name", cif_file)
    try:
        structure = process_structure(Structure.from_file(cif_path))
        slices_string = backend.structure2SLICES(structure)
        image_file = visualize_structure(structure)
    except Exception as e:
        # Report errors in the status box (like slices_to_cif does) rather
        # than writing the error text into the SLICES output field.
        return "", None, f"Conversion failed. Error: {e}", "", ""
    # The same string also pre-fills the SLICES->CIF input and the
    # augmentation tab's input.
    return slices_string, image_file, "Conversion successful.", slices_string, slices_string
def slices_to_cif(slices_string):
    """Reconstruct a crystal structure from a SLICES string and save it as CIF.

    Args:
        slices_string: The SLICES string entered in the UI.

    Returns:
        ``(cif_path, image_path, status_message)``. On success the paths
        point to newly written temporary files and the status message
        reports the energy returned by the backend. On failure both paths
        are ``None`` and the status message carries the error.
    """
    if slices_string is None or not slices_string.strip():
        return None, None, "Conversion failed. Error: No SLICES string provided."
    try:
        structure, energy = backend.SLICES2structure(slices_string)
        structure = process_structure(structure)
        # Close our handle before CifWriter reopens the path by name:
        # reopening a still-open NamedTemporaryFile fails on Windows.
        # The file is kept so Gradio can serve it for download.
        fd, cif_path = tempfile.mkstemp(suffix=".cif")
        os.close(fd)
        CifWriter(structure).write_file(cif_path)
        image_file = visualize_structure(structure)
        return cif_path, image_file, f"Conversion successful. Energy: {energy:.4f} eV/atom"
    except Exception as e:
        return None, None, f"Conversion failed. Error: {str(e)}"
def augment_and_canonicalize_slices(slices_string, num_augmentations):
    """Generate atom-order augmentations of a SLICES string and their canonical forms.

    Args:
        slices_string: The SLICES string to augment.
        num_augmentations: Number of augmentations to request. It is coerced
            to ``int`` in case the UI slider delivers a float.

    Returns:
        ``(augmented, canonical)``:
        - ``augmented`` is the list of augmented SLICES strings exactly as
          produced by the backend.
        - ``canonical`` is the list of distinct canonical SLICES obtained
          from them, in first-seen order.

    Raises:
        gr.Error: If augmentation or canonicalization fails, so that Gradio
            shows the message to the user.
    """
    try:
        augmented_slices = backend.SLICES2SLICESAug_atom_order(
            slices_string, num=int(num_augmentations)
        )
        # dict.fromkeys de-duplicates while keeping a deterministic order;
        # set() ordering varies between runs under string hash randomization.
        unique_augmented_slices = list(dict.fromkeys(augmented_slices))
        canonical_slices = list(dict.fromkeys(
            backend.get_canonical_SLICES(s) for s in unique_augmented_slices
        ))
    except Exception as e:
        # Only two outputs are wired to this callback, so report failures
        # through gr.Error instead of an extra return value.
        raise gr.Error(f"Augmentation failed. Error: {e}") from e
    return augmented_slices, canonical_slices
# Gradio interface
# Layout: title + banner, then two tabs -- CIF <-> SLICES conversion, and
# SLICES augmentation/canonicalization -- followed by the event wiring.
with gr.Blocks() as iface:
    # Title and banner image (loaded from the relative path "1.png").
    gr.Markdown("# Crystal Structure and SLICES Converter", elem_classes=["center"])
    with gr.Row(elem_classes=["center"]):
        gr.Image("1.png", label="SLICES Representation", show_label=False, width=600, height=250)
    gr.Markdown("SLICES provides a text-based encoding of crystal structures, allowing for efficient manipulation and generation of new materials.", elem_classes=["center"])
    with gr.Tab("CIF-SLICES Conversion"):
        with gr.Row():
            # Left column: inputs for both conversion directions.
            with gr.Column():
                # Selects which File component feeds cif_to_slices; these
                # label strings are compared verbatim in the handlers below.
                file_choice = gr.Radio(
                    ["Use example CIF (NdSiRu.cif)", "Upload custom CIF"],
                    label="Choose CIF source",
                    value="Use example CIF (NdSiRu.cif)"
                )
                # Bundled example CIF, pre-loaded and read-only.
                # NOTE(review): starts hidden even though the example option is
                # the default, so it only appears after the radio is toggled.
                example_file = gr.File(value="NdSiRu.cif", visible=False, interactive=False)
                # User upload; shown only when "Upload custom CIF" is selected.
                custom_file = gr.File(label="Upload CIF file", file_types=[".cif"], visible=False)
                convert_cif_button = gr.Button("Convert CIF to SLICES")
                # Also pre-filled with the SLICES string produced by cif_to_slices.
                slices_input = gr.Textbox(label="Enter SLICES String")
                convert_slices_button = gr.Button("Convert SLICES to CIF")
            # Right column: results. conversion_status is written by both
            # conversion directions.
            with gr.Column():
                slices_output = gr.Textbox(label="SLICES String")
                cif_output = gr.File(label="Download CIF", file_types=[".cif"])
                conversion_status = gr.Textbox(label="Conversion Status")
        # Renderings of the input structure (CIF->SLICES) and of the
        # reconstructed structure (SLICES->CIF).
        with gr.Row():
            cif_image = gr.Image(label="Original Structure")
            slices_image = gr.Image(label="Converted Structure")
    with gr.Tab("SLICES Augmentation and Canonicalization"):
        # Also pre-filled with the SLICES string produced by cif_to_slices.
        aug_slices_input = gr.Textbox(label="Enter SLICES String")
        num_augmentations = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Number of Augmentations")
        augment_button = gr.Button("Augment and Canonicalize")
        aug_slices_output = gr.Textbox(label="Augmented SLICES Strings")
        canon_slices_output = gr.Textbox(label="Canonical SLICES Strings")
    # Event handlers
    def update_file_visibility(choice):
        # Show only the File component that matches the selected CIF source.
        return gr.update(visible=choice == "Use example CIF (NdSiRu.cif)"), gr.update(visible=choice == "Upload custom CIF")
    file_choice.change(
        update_file_visibility,
        inputs=[file_choice],
        outputs=[example_file, custom_file]
    )
    def get_active_file(choice, example, custom):
        # Pick the value of whichever File component the radio selects.
        return example if choice == "Use example CIF (NdSiRu.cif)" else custom
    # Resolve the selected CIF source at click time, then convert. The five
    # outputs match the 5-tuple returned by cif_to_slices.
    convert_cif_button.click(
        lambda choice, example, custom: cif_to_slices(get_active_file(choice, example, custom)),
        inputs=[file_choice, example_file, custom_file],
        outputs=[slices_output, cif_image, conversion_status, slices_input, aug_slices_input]
    )
    # Outputs match the 3-tuple returned by slices_to_cif.
    convert_slices_button.click(
        slices_to_cif,
        inputs=[slices_input],
        outputs=[cif_output, slices_image, conversion_status]
    )
    # Exactly two outputs are wired; the callback must return two values.
    augment_button.click(
        augment_and_canonicalize_slices,
        inputs=[aug_slices_input, num_augmentations],
        outputs=[aug_slices_output, canon_slices_output]
    )
# Starts the server at import time (no __main__ guard); share=True requests a
# public share link.
iface.launch(share=True)