jannisborn committed on
Commit
19e399c
1 Parent(s): c83fa31
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ *.DS_Store
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Crystal properties (CGCNN)
3
  emoji: 💡
4
  colorFrom: green
5
  colorTo: blue
 
1
  ---
2
+ title: Crystal properties
3
  emoji: 💡
4
  colorFrom: green
5
  colorTo: blue
app.py CHANGED
@@ -1,169 +1,107 @@
1
  import logging
 
2
  import pathlib
 
 
 
3
 
4
  import gradio as gr
5
  import pandas as pd
6
- from gt4sd.algorithms.conditional_generation.regression_transformer import (
7
- RegressionTransformer,
8
- )
9
- from gt4sd.algorithms.registry import ApplicationsRegistry
10
- from utils import (
11
- draw_grid_generate,
12
- draw_grid_predict,
13
- get_application,
14
- get_inference_dict,
15
- get_rt_name,
16
- )
17
 
18
  logger = logging.getLogger(__name__)
19
  logger.addHandler(logging.NullHandler())
20
 
 
21
 
22
- def regression_transformer(
23
- algorithm: str,
24
- task: str,
25
- target: str,
26
- number_of_samples: int,
27
- search: str,
28
- temperature: float,
29
- tolerance: int,
30
- wrapper: bool,
31
- fraction_to_mask: float,
32
- property_goal: str,
33
- tokens_to_mask: str,
34
- substructures_to_mask: str,
35
- substructures_to_keep: str,
36
- ):
37
-
38
- if task == "Predict" and wrapper:
39
- logger.warning(
40
- f"For prediction, no sampling_wrapper will be used, ignoring: fraction_to_mask: {fraction_to_mask}, "
41
- f"tokens_to_mask: {tokens_to_mask}, substructures_to_mask={substructures_to_mask}, "
42
- f"substructures_to_keep: {substructures_to_keep}."
43
- )
44
- sampling_wrapper = {}
45
- elif not wrapper:
46
- sampling_wrapper = {}
 
 
 
 
 
 
 
 
 
 
 
 
47
  else:
48
- substructures_to_mask = (
49
- []
50
- if substructures_to_mask == ""
51
- else substructures_to_mask.replace(" ", "").split(",")
52
  )
53
- substructures_to_keep = (
54
- []
55
- if substructures_to_keep == ""
56
- else substructures_to_keep.replace(" ", "").split(",")
57
- )
58
- tokens_to_mask = [] if tokens_to_mask == "" else tokens_to_mask.split(",")
59
-
60
- property_goals = {}
61
- if property_goal == "":
62
- raise ValueError(
63
- "For conditional generation you have to specify `property_goal`."
64
- )
65
- for line in property_goal.split(","):
66
- property_goals[line.split(":")[0].strip()] = float(line.split(":")[1])
67
-
68
- sampling_wrapper = {
69
- "substructures_to_keep": substructures_to_keep,
70
- "substructures_to_mask": substructures_to_mask,
71
- "text_filtering": False,
72
- "fraction_to_mask": fraction_to_mask,
73
- "property_goal": property_goals,
74
- }
75
- algorithm_application = get_application(algorithm.split(":")[0])
76
- algorithm_version = algorithm.split(" ")[-1].lower()
77
- config = algorithm_application(
78
- algorithm_version=algorithm_version,
79
- search=search.lower(),
80
- temperature=temperature,
81
- tolerance=tolerance,
82
- sampling_wrapper=sampling_wrapper,
83
- )
84
- model = RegressionTransformer(configuration=config, target=target)
85
- samples = list(model.sample(number_of_samples))
86
- if algorithm_version == "polymer" and task == "Generate":
87
- correct_samples = [(s, p) for s, p in samples if "." in s]
88
- while len(correct_samples) < number_of_samples:
89
- samples = list(model.sample(number_of_samples))
90
- correct_samples.extend(
91
- [
92
- (s, p)
93
- for s, p in samples
94
- if "." in s and (s, p) not in correct_samples
95
- ]
96
- )
97
- samples = correct_samples
98
- if task == "Predict":
99
- return draw_grid_predict(samples[0], target, domain=algorithm.split(":")[0])
100
- else:
101
- return draw_grid_generate(samples, domain=algorithm.split(":")[0])
102
 
103
 
104
  if __name__ == "__main__":
105
 
106
  # Preparation (retrieve all available algorithms)
107
- all_algos = ApplicationsRegistry.list_available()
108
- rt_algos = list(
109
- filter(lambda x: "RegressionTransformer" in x["algorithm_name"], all_algos)
110
- )
111
- rt_names = list(map(get_rt_name, rt_algos))
112
-
113
- properties = {}
114
- for algo in rt_algos:
115
- application = get_application(
116
- algo["algorithm_application"].split("Transformer")[-1]
117
- )
118
- data = get_inference_dict(
119
- application=application, algorithm_version=algo["algorithm_version"]
120
- )
121
- properties[get_rt_name(algo)] = data
122
- properties
123
 
124
  # Load metadata
125
  metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
126
 
127
- examples = pd.read_csv(
128
- metadata_root.joinpath("regression_transformer_examples.csv"), header=None
129
- ).fillna("")
 
 
 
130
 
131
- with open(metadata_root.joinpath("regression_transformer_article.md"), "r") as f:
132
  article = f.read()
133
- with open(
134
- metadata_root.joinpath("regression_transformer_description.md"), "r"
135
- ) as f:
136
  description = f.read()
137
 
138
  demo = gr.Interface(
139
- fn=regression_transformer,
140
- title="Regression Transformer",
141
  inputs=[
142
- gr.Dropdown(rt_names, label="Algorithm version", value="Molecules: Qed"),
143
- gr.Radio(choices=["Predict", "Generate"], label="Task", value="Generate"),
144
- gr.Textbox(
145
- label="Input", placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1", lines=1
146
- ),
147
- gr.Slider(
148
- minimum=1, maximum=50, value=10, label="Number of samples", step=1
149
- ),
150
- gr.Radio(choices=["Sample", "Greedy"], label="Search", value="Sample"),
151
- gr.Slider(minimum=0.5, maximum=2, value=1, label="Decoding temperature"),
152
- gr.Slider(minimum=5, maximum=100, value=30, label="Tolerance", step=1),
153
- gr.Radio(choices=[True, False], label="Sampling Wrapper", value=True),
154
- gr.Slider(minimum=0, maximum=1, value=0.5, label="Fraction to mask"),
155
- gr.Textbox(label="Property goal", placeholder="<qed>:0.75", lines=1),
156
- gr.Textbox(label="Tokens to mask", placeholder="N, C", lines=1),
157
- gr.Textbox(
158
- label="Substructures to mask", placeholder="C(=O), C#C", lines=1
159
- ),
160
- gr.Textbox(
161
- label="Substructures to keep", placeholder="C1=CC=C(Cl)C=C1", lines=1
162
  ),
163
  ],
164
- outputs=gr.HTML(label="Output"),
165
  article=article,
166
  description=description,
167
- examples=examples.values.tolist(),
168
  )
169
  demo.launch(debug=True, show_error=True)
 
1
  import logging
2
+ import os
3
  import pathlib
4
+ import shutil
5
+ import tempfile
6
+ from pathlib import Path
7
 
8
  import gradio as gr
9
  import pandas as pd
10
+ from gt4sd.properties.crystals import CRYSTALS_PROPERTY_PREDICTOR_FACTORY
 
 
 
 
 
 
 
 
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
  logger.addHandler(logging.NullHandler())
14
 
15
+ suffix_dict = {"metal_nonmetal_classifier": ".csv"}
16
 
17
+
18
def create_temp_file(path: str) -> str:
    """Copy ``path`` into the app's scratch folder and return the copy's path.

    The scratch folder is ``<system temp dir>/gt4sd_crystal``. It is shared by
    all calls and emptied before every copy, so right after the call it
    contains only the copied file.

    Args:
        path: Path of the file to copy.

    Returns:
        Path of the copy inside the scratch folder.
    """
    temp_folder = os.path.join(tempfile.gettempdir(), "gt4sd_crystal")
    os.makedirs(temp_folder, exist_ok=True)

    # Clean up directory. Leftovers may be sub-directories (e.g. from an
    # archive that was unpacked into this folder), which `os.remove` cannot
    # delete, so those are removed recursively.
    for entry in os.listdir(temp_folder):
        print("Removing", entry)
        entry_path = os.path.join(temp_folder, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    # `os.path.basename` honours the platform's path separators, unlike a
    # manual split on "/".
    temp_path = os.path.join(temp_folder, os.path.basename(path))
    shutil.copy2(path, temp_path)
    return temp_path
30
+
31
+
32
def main(property: str, data_file: str):
    """Predict a crystal property for an uploaded file.

    Args:
        property: Display name of the property (e.g. "Formation Energy").
            Spaces are replaced by underscores and the result is lower-cased
            to obtain the key into `CRYSTALS_PROPERTY_PREDICTOR_FACTORY`.
        data_file: The uploaded file object, or `None` if nothing was
            uploaded. Its `name` attribute is used as the path of the file on
            disk; `orig_name` is only printed.

    Returns:
        The model output wrapped in a `pandas.DataFrame`.

    Raises:
        TypeError: If no file was passed, or the file suffix is not one of
            `.cif`, `.csv` or `.zip`.
        ValueError: If a `.zip` was passed and no `.cif` file is found in the
            folder it was unpacked into.
    """
    # Validate first: reading attributes of a missing upload would otherwise
    # fail with an unrelated AttributeError.
    if data_file is None:
        raise TypeError("You have to pass an input file for the crystal model")

    print(data_file, data_file.orig_name, data_file.name)

    # Copy file into the app's (shared, freshly emptied) temporary directory
    file_path = Path(create_temp_file(data_file.name))
    folder = file_path.parent
    print(file_path)
    print(folder)
    if file_path.suffix == ".cif":
        # The model is given the folder that holds the `.cif` file.
        input_path = folder
    elif file_path.suffix == ".csv":
        input_path = file_path
    elif file_path.suffix == ".zip":
        # Unzip next to the archive and hand the whole folder to the model.
        shutil.unpack_archive(file_path, folder)
        if not any(name.endswith(".cif") for name in os.listdir(folder)):
            raise ValueError("No `.cif` files were found inside the `.zip`.")
        input_path = folder
    else:
        raise TypeError(
            "You have to pass a `.csv` (for `metal_nonmetal_classifier`),"
            " a `.cif` (for all other properties) or a `.zip` with multiple"
            f" `.cif` files. Not `{file_path.suffix}`."
        )

    prop_name = property.replace(" ", "_").lower()
    algo, config = CRYSTALS_PROPERTY_PREDICTOR_FACTORY[prop_name]
    # Pass hyperparameters if applicable
    kwargs = {"algorithm_version": "v0"}
    model = algo(config(**kwargs))

    result = model(input=input_path)
    return pd.DataFrame(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
if __name__ == "__main__":

    # Preparation: build the dropdown labels from the registered property
    # keys (reversed order, underscores replaced by spaces, title-cased).
    properties = [
        key.replace("_", " ").title()
        for key in reversed(list(CRYSTALS_PROPERTY_PREDICTOR_FACTORY.keys()))
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Each example is [property label, input file]. Labels are written in the
    # same title-cased form as the dropdown choices above.
    examples = [
        ["Formation Energy", metadata_root.joinpath("7206075.cif")],
        ["Bulk Moduli", metadata_root.joinpath("crystals.zip")],
        ["Metal Nonmetal Classifier", metadata_root.joinpath("metal.csv")],
        ["Bulk Moduli", metadata_root.joinpath("9000046.cif")],
    ]

    # Read the markdown shown around the interface; the encoding is given
    # explicitly so the result does not depend on the platform default.
    with open(metadata_root.joinpath("article.md"), "r", encoding="utf-8") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r", encoding="utf-8") as f:
        description = f.read()

    demo = gr.Interface(
        fn=main,
        title="Crystal properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Instability"),
            gr.File(
                file_types=[".cif", ".csv", ".zip"],
                label="Input file for crystal model",
            ),
        ],
        outputs=gr.DataFrame(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)
model_cards/7206075.cif ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #------------------------------------------------------------------------------
2
+ #$Date: 2016-03-26 17:23:40 +0200 (Sat, 26 Mar 2016) $
3
+ #$Revision: 180391 $
4
+ #$URL: svn://www.crystallography.net/cod/cif/7/20/60/7206075.cif $
5
+ #------------------------------------------------------------------------------
6
+ #
7
+ # This file is available in the Crystallography Open Database (COD),
8
+ # http://www.crystallography.net/
9
+ #
10
+ # All data on this site have been placed in the public domain by the
11
+ # contributors.
12
+ #
13
+ data_7206075
14
+ loop_
15
+ _publ_author_name
16
+ 'Rezaee, Masih'
17
+ 'Mousavi Khoie, Seyyed Mohammad'
18
+ 'Liu, Kun Hua'
19
+ _publ_section_title
20
+ ;
21
+ The role of brookite in mechanical activation of anatase-to-rutile
22
+ transformation of nanocrystalline TiO2: An XRD and Raman spectroscopy
23
+ investigation
24
+ ;
25
+ _journal_issue 16
26
+ _journal_name_full CrystEngComm
27
+ _journal_page_first 5055
28
+ _journal_paper_doi 10.1039/c1ce05185g
29
+ _journal_volume 13
30
+ _journal_year 2011
31
+ _chemical_formula_structural 'Ti O2'
32
+ _chemical_formula_sum 'O2 Ti'
33
+ _chemical_name_mineral Anatase
34
+ _chemical_name_systematic 'Titanium oxide'
35
+ _space_group_IT_number 141
36
+ _symmetry_cell_setting tetragonal
37
+ _symmetry_Int_Tables_number 141
38
+ _symmetry_space_group_name_Hall 'I 4bw 2bw -1bw'
39
+ _symmetry_space_group_name_H-M 'I 41/a m d :1'
40
+ _audit_update_record
41
+ ;
42
+ 2011-02-06 # Formatted by publCIF
43
+ ;
44
+ _cell_angle_alpha 90
45
+ _cell_angle_beta 90
46
+ _cell_angle_gamma 90
47
+ _cell_formula_units_Z 4
48
+ _cell_length_a 3.7850
49
+ _cell_length_b 3.7850
50
+ _cell_length_c 9.5196
51
+ _cell_measurement_temperature 298
52
+ _cell_volume 136.380
53
+ _computing_cell_refinement MAUD
54
+ _computing_data_collection X'Pert
55
+ _computing_data_reduction MAUD
56
+ _computing_publication_material publCIF
57
+ _computing_structure_refinement MAUD
58
+ _computing_structure_solution MAUD
59
+ _diffrn_measurement_device_type 'GBC MMA X-ray diffractometer'
60
+ _diffrn_radiation_source 'Cu K\a'
61
+ _diffrn_radiation_type 'Cu K\a'
62
+ _diffrn_radiation_wavelength 1.541874
63
+ _diffrn_reflns_theta_max 50
64
+ _cod_data_source_file 400a_anatase_.txt
65
+ _cod_data_source_block 11A
66
+ _cod_original_sg_symbol_H-M 'I 41/a m d S'
67
+ _cod_database_code 7206075
68
+ loop_
69
+ _symmetry_equiv_pos_as_xyz
70
+ x,y,z
71
+ -x,-y,z
72
+ x,1/2+y,1/4-z
73
+ -x,1/2-y,1/4-z
74
+ -x,y,z
75
+ x,-y,z
76
+ -x,1/2+y,1/4-z
77
+ x,1/2-y,1/4-z
78
+ y,x,-z
79
+ -y,-x,-z
80
+ y,1/2+x,1/4+z
81
+ -y,1/2-x,1/4+z
82
+ -y,x,-z
83
+ y,-x,-z
84
+ -y,1/2+x,1/4+z
85
+ y,1/2-x,1/4+z
86
+ 1/2+x,1/2+y,1/2+z
87
+ 1/2-x,1/2-y,1/2+z
88
+ 1/2+x,y,3/4-z
89
+ 1/2-x,-y,3/4-z
90
+ 1/2-x,1/2+y,1/2+z
91
+ 1/2+x,1/2-y,1/2+z
92
+ 1/2-x,y,3/4-z
93
+ 1/2+x,-y,3/4-z
94
+ 1/2+y,1/2+x,1/2-z
95
+ 1/2-y,1/2-x,1/2-z
96
+ 1/2+y,x,3/4+z
97
+ 1/2-y,-x,3/4+z
98
+ 1/2-y,1/2+x,1/2-z
99
+ 1/2+y,1/2-x,1/2-z
100
+ 1/2-y,x,3/4+z
101
+ 1/2+y,-x,3/4+z
102
+ loop_
103
+ _atom_site_label
104
+ _atom_site_type_symbol
105
+ _atom_site_symmetry_multiplicity
106
+ _atom_site_Wyckoff_symbol
107
+ _atom_site_fract_x
108
+ _atom_site_fract_y
109
+ _atom_site_fract_z
110
+ _atom_site_occupancy
111
+ _atom_site_attached_hydrogens
112
+ _atom_site_calc_flag
113
+ Ti1 Ti4+ 4 a 0. 0. 0. 1. 0 d
114
+ O1 O2- 8 e 0. 0. 0.21017 1. 0 d
115
+ loop_
116
+ _atom_type_symbol
117
+ _atom_type_oxidation_number
118
+ Ti4+ 4.000
119
+ O2- -2.000
model_cards/9000046.cif ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #------------------------------------------------------------------------------
2
+ #$Date: 2013-05-05 17:21:46 +0300 (Sun, 05 May 2013) $
3
+ #$Revision: 85285 $
4
+ #$URL: svn://www.crystallography.net/cod/cif/9/00/00/9000046.cif $
5
+ #------------------------------------------------------------------------------
6
+ #
7
+ # This file is available in the Crystallography Open Database (COD),
8
+ # http://www.crystallography.net/. The original data for this entry
9
+ # were provided the American Mineralogist Crystal Structure Database,
10
+ # http://rruff.geo.arizona.edu/AMS/amcsd.php
11
+ #
12
+ # The file may be used within the scientific community so long as
13
+ # proper attribution is given to the journal article from which the
14
+ # data were obtained.
15
+ #
16
+ data_9000046
17
+ loop_
18
+ _publ_author_name
19
+ 'Kukesh, J. S.'
20
+ 'Pauling, L.'
21
+ _publ_section_title
22
+ ;
23
+ The problem of the graphite structure
24
+ ;
25
+ _journal_name_full 'American Mineralogist'
26
+ _journal_page_first 125
27
+ _journal_page_last 125
28
+ _journal_volume 35
29
+ _journal_year 1950
30
+ _chemical_formula_sum C
31
+ _chemical_name_common Graphite
32
+ _chemical_name_mineral Graphite
33
+ _space_group_IT_number 69
34
+ _symmetry_space_group_name_Hall '-F 2 2'
35
+ _symmetry_space_group_name_H-M 'F m m m'
36
+ _cell_angle_alpha 90
37
+ _cell_angle_beta 90
38
+ _cell_angle_gamma 90
39
+ _cell_length_a 2.456
40
+ _cell_length_b 4.254
41
+ _cell_length_c 6.696
42
+ _cell_volume 69.959
43
+ _exptl_crystal_density_diffrn 2.281
44
+ _cod_database_code 9000046
45
+ loop_
46
+ _symmetry_equiv_pos_as_xyz
47
+ x,y,z
48
+ x,1/2+y,1/2+z
49
+ 1/2+x,y,1/2+z
50
+ 1/2+x,1/2+y,z
51
+ x,-y,z
52
+ x,1/2-y,1/2+z
53
+ 1/2+x,-y,1/2+z
54
+ 1/2+x,1/2-y,z
55
+ -x,y,-z
56
+ -x,1/2+y,1/2-z
57
+ 1/2-x,y,1/2-z
58
+ 1/2-x,1/2+y,-z
59
+ -x,y,z
60
+ -x,1/2+y,1/2+z
61
+ 1/2-x,y,1/2+z
62
+ 1/2-x,1/2+y,z
63
+ x,-y,-z
64
+ x,1/2-y,1/2-z
65
+ 1/2+x,-y,1/2-z
66
+ 1/2+x,1/2-y,-z
67
+ x,y,-z
68
+ x,1/2+y,1/2-z
69
+ 1/2+x,y,1/2-z
70
+ 1/2+x,1/2+y,-z
71
+ -x,-y,z
72
+ -x,1/2-y,1/2+z
73
+ 1/2-x,-y,1/2+z
74
+ 1/2-x,1/2-y,z
75
+ -x,-y,-z
76
+ -x,1/2-y,1/2-z
77
+ 1/2-x,-y,1/2-z
78
+ 1/2-x,1/2-y,-z
79
+ loop_
80
+ _atom_site_label
81
+ _atom_site_fract_x
82
+ _atom_site_fract_y
83
+ _atom_site_fract_z
84
+ C 0.00000 0.16667 0.00000
model_cards/{regression_transformer_article.md → article.md} RENAMED
File without changes
model_cards/crystals.zip ADDED
Binary file (3.07 kB). View file
 
model_cards/description.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <img align="right" src="https://raw.githubusercontent.com/GT4SD/gt4sd-core/main/docs/_static/gt4sd_logo.png" alt="logo" width="120" >
4
+
5
+ ### Crystal property prediction
6
+
7
+ This is the GT4SD web-app for prediction of various crystal properties. For **examples** and **documentation** of the supported properties, please see below.
8
+ Enjoy :)
model_cards/metal.csv ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ KPSO2,orthorhombic
2
+ Zr2Ga(PO4)3,trigonal
3
+ Te4Mo(WSe)2,trigonal
4
+ Mo3W(SeS3)2,trigonal
5
+ Te2Mo2SeS,trigonal
6
+ Mo3W(Se3S)2,trigonal
7
+ MoWSe3S,trigonal
8
+ Mo3W(SeS)4,trigonal
9
+ Te2Mo3WSe6,trigonal
10
+ TeW2SeS2,trigonal
11
+ Te4MoW3S4,trigonal
12
+ Te6Mo3WS2,trigonal
13
+ KMg6CO8,tetragonal
14
+ Mg14BiBO16,orthorhombic
15
+ KMg14WO16,tetragonal
16
+ Mg14AlCdO16,orthorhombic
17
+ Mg30VCrO32,tetragonal
18
+ Mg30CoSiO32,tetragonal
19
+ YMg30CO32,tetragonal
20
+ KYMg30O32,tetragonal
21
+ CaMg30MnO32,tetragonal
22
+ CaMg30CO32,tetragonal
23
+ LiMg30MnO32,tetragonal
24
+ CaMg30NiO32,tetragonal
25
+ LiMg30AlO32,tetragonal
26
+ Mg30AlFeO32,tetragonal
27
+ RbMg30SbO32,tetragonal
28
+ KNaMg30O32,orthorhombic
29
+ La7Sm(Fe2O5)4,triclinic
30
+ SrCa3Mn4O12,triclinic
31
+ NbNi3(HC)2,tetragonal
32
+ La2P2AuO,monoclinic
33
+ Li9Mn2Co5O16,monoclinic
34
+ Li9Mn2Co5O16,monoclinic
35
+ Li9Mn2Co5O16,monoclinic
36
+ Li9Mn2Co5O16,monoclinic
37
+ Li9Mn2Co5O16,monoclinic
38
+ LiCrP4O13,triclinic
39
+ LiCr4P7O24,triclinic
40
+ ZnGe(OF)6,trigonal
41
+ Cs2Mo(SO)2,monoclinic
42
+ NaMgSO7,monoclinic
43
+ K2NaNdCl6,cubic
44
+ K2NaBiCl6,cubic
45
+ Na2EuCuCl6,cubic
46
+ NaLi2CoF6,cubic
47
+ K2NaTiF6,cubic
48
+ K2AgRhF6,cubic
49
+ K2CeAgCl6,cubic
50
+ K2ErCuCl6,cubic
model_cards/regression_transformer.png DELETED
Binary file (225 kB)
 
model_cards/regression_transformer_description.md DELETED
@@ -1,13 +0,0 @@
1
-
2
-
3
- <img align="right" src="https://raw.githubusercontent.com/GT4SD/gt4sd-core/main/docs/_static/gt4sd_logo.png" alt="logo" width="120" >
4
-
5
- ### Concurrent sequence regression and generation for molecular language modeling
6
-
7
- The [Regression Transformer](https://arxiv.org/abs/2202.01338) is a multitask Transformer that reformulates regression as a conditional sequence modeling task.
8
- This yields a dichotomous language model that seamlessly integrates property prediction with property-driven conditional generation. For details see the [arXiv preprint](https://arxiv.org/abs/2202.01338), the [development code](https://github.com/IBM/regression-transformer) and the [GT4SD endpoint](https://github.com/GT4SD/gt4sd-core) for inference.
9
-
10
- Each `algorithm_version` refers to one trained model. Each model can be used for **two tasks**, either to *predict* one (or multiple) properties of a molecule or to *generate* a molecule (given a seed molecule and a property constraint).
11
-
12
- For **examples** and **documentation** of the model parameters, please see below.
13
- Moreover, we provide a **model card** ([Mitchell et al. (2019)](https://dl.acm.org/doi/abs/10.1145/3287560.3287596?casa_token=XD4eHiE2cRUAAAAA:NL11gMa1hGPOUKTAbtXnbVQBDBbjxwcjGECF_i-WC_3g1aBgU1Hbz_f2b4kI_m1in-w__1ztGeHnwHs)) at the bottom of this page.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_cards/regression_transformer_examples.csv DELETED
@@ -1,9 +0,0 @@
1
- Molecules: Logp_and_synthesizability,Generate,CCOC1=NC=NC(=C1C)NCCOC(C)C,3,Sample,1.2,20,True,0.3,"<logp>:0.390, <scs>:2.628",N,(C)C,CCO
2
- Molecules: Qed,Generate,CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1,10,Sample,1.0,30,True,0.5,<qed>:0.75,"N, C","C(=O), CC",C1=CC=C(Cl)C=C1
3
- Molecules: Logp_and_synthesizability,Predict,<logp>[MASK][MASK][MASK][MASK][MASK]|<scs>[MASK][MASK][MASK][MASK][MASK]|[C][C][O][C][=N][C][=N][C][Branch1_2][Branch1_1][=C][Ring1][Branch1_2][C][N][C][C][O][C][Branch1_1][C][C][C],1,Greedy,1.0,30,False,0.0,,,,
4
- Proteins: Stability,Predict,<stab>[MASK][MASK][MASK][MASK][MASK]|GSQEVNSGTQTYKNASPEEAERIARKAGATTWTEKGNKWEIRI,1,Greedy,1.0,1,False,0.0,,,,
5
- Proteins: Stability,Generate,GSQEVNSGTQTYKNASPEEAERIARKAGATTWTEKGNKWEIRI,10,Sample,1.2,30,True,0.3,<stab>:0.393,,SQEVNSGTQTYKN,WTEK
6
- Molecules: Qed,Generate,<qed>0.717|[MASK][MASK][MASK][MASK][MASK][C][Branch2_1][Ring1][Ring1][MASK][MASK][=C][C][Branch1_1][C][C][=N][C][MASK][MASK][=C][C][=C][Ring1][O][Ring1][Branch1_2][=C][Ring2][MASK][MASK],10,Sample,1.2,30,False,0.0,,,,
7
- Molecules: Solubility,Generate,ClC(Cl)C(Cl)Cl,5,Sample,1.3,40,True,0.4,<esol>:0.754,,,
8
- Molecules: Polymer,Predict,<conv>[MASK][MASK][MASK][MASK]|<pdi>[MASK][MASK][MASK][MASK][MASK]|<molwt>[MASK][MASK][MASK][MASK][MASK]|[C][Branch1_2][C][=O][O][C@@Hexpl][Branch1_1][C][C][C][Branch1_2][C][=O][O][C@Hexpl][Ring1][Branch2_2][C].[C][C][C][Branch2_1][Ring1][Ring1][N][C][Branch1_1][=C][N][C][=C][C][=C][Branch1_1][Ring1][O][C][C][=C][Ring1][Branch2_1][=S][C][C][C][Ring2][Ring1][C],1,Greedy,1,0,False,,,,,
9
- Molecules: Polymer,Generate,C1(=O)O[C@@H](C)C(=O)O[C@H]1C.C2CC(NC(NC1=CC=C(OC)C=C1)=S)CCC2,10,Sample,1.3,50,True,0.5,"<pdi>:3.490, <conv>:0.567, <molwt>:3.567",,,C1(=O)O[C@@H](C)C(=O)O[C@H]1C
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -8,7 +8,7 @@ torch-sparse
8
  torch-geometric
9
  torchvision==0.13.1
10
  torchaudio==0.12.1
11
- gt4sd>=1.0.6
12
  molgx>=0.22.0a1
13
  molecule_generation
14
  nglview
 
8
  torch-geometric
9
  torchvision==0.13.1
10
  torchaudio==0.12.1
11
+ gt4sd>=1.1.4
12
  molgx>=0.22.0a1
13
  molecule_generation
14
  nglview
utils.py CHANGED
@@ -1,16 +1,7 @@
1
- import json
2
  import logging
3
- import os
4
- from collections import defaultdict
5
- from typing import Dict, List, Tuple
6
 
7
  import mols2grid
8
  import pandas as pd
9
- from gt4sd.algorithms import (
10
- RegressionTransformerMolecules,
11
- RegressionTransformerProteins,
12
- )
13
- from gt4sd.algorithms.core import AlgorithmConfiguration
14
  from rdkit import Chem
15
  from terminator.selfies import decoder
16
 
@@ -18,63 +9,6 @@ logger = logging.getLogger(__name__)
18
  logger.addHandler(logging.NullHandler())
19
 
20
 
21
- def get_application(application: str) -> AlgorithmConfiguration:
22
- """
23
- Convert application name to AlgorithmConfiguration.
24
-
25
- Args:
26
- application: Molecules or Proteins
27
-
28
- Returns:
29
- The corresponding AlgorithmConfiguration
30
- """
31
- if application == "Molecules":
32
- application = RegressionTransformerMolecules
33
- elif application == "Proteins":
34
- application = RegressionTransformerProteins
35
- else:
36
- raise ValueError(
37
- "Currently only models for molecules and proteins are supported"
38
- )
39
- return application
40
-
41
-
42
- def get_inference_dict(
43
- application: AlgorithmConfiguration, algorithm_version: str
44
- ) -> Dict:
45
- """
46
- Get inference dictionary for a given application and algorithm version.
47
-
48
- Args:
49
- application: algorithm application (Molecules or Proteins)
50
- algorithm_version: algorithm version (e.g. qed)
51
-
52
- Returns:
53
- A dictionary with the inference parameters.
54
- """
55
- config = application(algorithm_version=algorithm_version)
56
- with open(os.path.join(config.ensure_artifacts(), "inference.json"), "r") as f:
57
- data = json.load(f)
58
- return data
59
-
60
-
61
- def get_rt_name(x: Dict) -> str:
62
- """
63
- Get the UI display name of the regression transformer.
64
-
65
- Args:
66
- x: dictionary with the inference parameters
67
-
68
- Returns:
69
- The display name
70
- """
71
- return (
72
- x["algorithm_application"].split("Transformer")[-1]
73
- + ": "
74
- + x["algorithm_version"].capitalize()
75
- )
76
-
77
-
78
  def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
79
  """
80
  Uses mols2grid to draw a HTML grid for the prediction
@@ -118,55 +52,3 @@ def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
118
  size=(600, 700),
119
  )
120
  return obj.data
121
-
122
-
123
- def draw_grid_generate(
124
- samples: List[Tuple[str]], domain: str, n_cols: int = 5, size=(140, 200)
125
- ) -> str:
126
- """
127
- Uses mols2grid to draw a HTML grid for the generated molecules
128
-
129
- Args:
130
- samples: The generated samples (with properties)
131
- domain: Domain of the prediction (molecules or proteins)
132
- n_cols: Number of columns in grid. Defaults to 5.
133
- size: Size of molecule in grid. Defaults to (140, 200).
134
-
135
- Returns:
136
- HTML to display
137
- """
138
-
139
- if domain not in ["Molecules", "Proteins"]:
140
- raise ValueError(f"Unsupported domain {domain}")
141
-
142
- if domain == "Proteins":
143
- try:
144
- smis = list(
145
- map(lambda x: Chem.MolToSmiles(Chem.MolFromFASTA(x[0])), samples)
146
- )
147
- except Exception:
148
- logger.warning(f"Could not convert some sequences {samples}")
149
- else:
150
- smis = [s[0] for s in samples]
151
-
152
- result = defaultdict(list)
153
- result.update({"SMILES": smis, "Name": [f"sample_{i}" for i in range(len(smis))]})
154
-
155
- # Create properties
156
- properties = [s.split("<")[1] for s in samples[0][1].split(">")[:-1]]
157
- # Fill properties
158
- for sample in samples:
159
- for prop in properties:
160
- value = float(sample[1].split(prop)[-1][1:].split("<")[0])
161
- result[prop].append(f"{prop} = {value}")
162
-
163
- result_df = pd.DataFrame(result)
164
- obj = mols2grid.display(
165
- result_df,
166
- tooltip=list(result.keys()),
167
- height=1100,
168
- n_cols=n_cols,
169
- name="Results",
170
- size=size,
171
- )
172
- return obj.data
 
 
1
  import logging
 
 
 
2
 
3
  import mols2grid
4
  import pandas as pd
 
 
 
 
 
5
  from rdkit import Chem
6
  from terminator.selfies import decoder
7
 
 
9
  logger.addHandler(logging.NullHandler())
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
13
  """
14
  Uses mols2grid to draw a HTML grid for the prediction
 
52
  size=(600, 700),
53
  )
54
  return obj.data