maykcaldas commited on
Commit
77cbf82
1 Parent(s): 1a126f0

Upload 7 files

Browse files
Files changed (4) hide show
  1. agent.py +4 -0
  2. mapi_tools.py +1 -1
  3. reaction_prediction.py +150 -0
  4. utils.py +1 -0
agent.py CHANGED
@@ -1,5 +1,6 @@
1
  from mapi_tools import MAPI_class_tools, MAPI_reg_tools
2
  from utils import common_tools
 
3
  from langchain import OpenAI
4
  from gpt_index import GPTListIndex, GPTIndexMemory
5
  from langchain import agents
@@ -44,6 +45,8 @@ ionic_energy = MAPI_reg_tools(
44
  total_energy = MAPI_reg_tools(
45
  "e_total","total energy"
46
  )
 
 
47
 
48
  class Agent:
49
  def __init__(self, openai_api_key, mapi_api_key):
@@ -63,6 +66,7 @@ class Agent:
63
  electronic_energy.get_tools() +
64
  ionic_energy.get_tools() +
65
  total_energy.get_tools() +
 
66
  agents.load_tools(["llm-math", "python_repl"], llm=llm) +
67
  common_tools
68
  )
 
1
  from mapi_tools import MAPI_class_tools, MAPI_reg_tools
2
  from utils import common_tools
3
+ from reaction_prediction import SynthesisReactions
4
  from langchain import OpenAI
5
  from gpt_index import GPTListIndex, GPTIndexMemory
6
  from langchain import agents
 
45
  total_energy = MAPI_reg_tools(
46
  "e_total","total energy"
47
  )
48
+ reaction = SynthesisReactions()
49
+
50
 
51
  class Agent:
52
  def __init__(self, openai_api_key, mapi_api_key):
 
66
  electronic_energy.get_tools() +
67
  ionic_energy.get_tools() +
68
  total_energy.get_tools() +
69
+ reaction.get_tools() +
70
  agents.load_tools(["llm-math", "python_repl"], llm=llm) +
71
  common_tools
72
  )
mapi_tools.py CHANGED
@@ -212,4 +212,4 @@ class MAPI_reg_tools(MAPITools):
212
  suffix=suffix,
213
  input_variables=["formula"])
214
 
215
- return prompt.format(formula=formula)
 
212
  suffix=suffix,
213
  input_variables=["formula"])
214
 
215
+ return prompt.format(formula=formula)
reaction_prediction.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from langchain.agents import Tool, tool
4
+ # from mp_api.client import MPRester
5
+ from pymatgen.ext.matproj import MPRester
6
+ from rxn_network.entries.entry_set import GibbsEntrySet
7
+ from rxn_network.enumerators.basic import BasicEnumerator
8
+
9
+ class SynthesisReactions:
10
+ def __init__(self, temp=900, stabl=0.025, exclusive_precursors=False, exclusive_targets=False):
11
+ self.temp = temp
12
+ self.stabl = stabl
13
+ self.exclusive_precursors = exclusive_precursors
14
+ self.exclusive_targets = exclusive_targets
15
+
16
+ def _split_string(self, s):
17
+ if isinstance(s, list):
18
+ s = "".join(s)
19
+ parts = re.findall('[a-z]+|[A-Z][a-z]*', s)
20
+ letters_only = [re.sub(r'\d+', '', part) for part in parts]
21
+ unique_letters = list(set(letters_only))
22
+ result = "-".join(unique_letters)
23
+ return result
24
+
25
+ def _get_rxn_from_precursor(self, precursors_formulas):
26
+ prec = precursors_formulas.split(',') if "," in precursors_formulas else precursors_formulas
27
+
28
+ with MPRester(os.getenv("MAPI_API_KEY")) as mpr:
29
+ entries = mpr.get_entries_in_chemsys(self._split_string(prec))
30
+
31
+ gibbs_entries = GibbsEntrySet.from_computed_entries(entries, self.temp)
32
+ filtered_entries = gibbs_entries.filter_by_stability(self.stabl)
33
+
34
+ prec = [prec] if isinstance(prec, str) else prec
35
+ be = BasicEnumerator(precursors=prec, exclusive_precursors=self.exclusive_precursors)
36
+ rxns = be.enumerate(filtered_entries)
37
+ try:
38
+ rxn_choice = next(iter(rxns))
39
+ return str(rxn_choice)
40
+ except:
41
+ return "Error: No reactions found."
42
+
43
+ def _get_rxn_from_target(self, targets_formulas):
44
+ targets = targets_formulas.split(',') if "," in targets_formulas else targets_formulas
45
+
46
+ with MPRester(os.getenv("MAPI_API_KEY")) as mpr:
47
+ entries = mpr.get_entries_in_chemsys(self._split_string(targets))
48
+
49
+ gibbs_entries = GibbsEntrySet.from_computed_entries(entries, self.temp)
50
+ filtered_entries = gibbs_entries.filter_by_stability(self.stabl)
51
+
52
+ targets = [targets] if isinstance(targets, str) else targets
53
+
54
+ be = BasicEnumerator(targets=targets, exclusive_targets=self.exclusive_targets)
55
+ rxns = be.enumerate(filtered_entries)
56
+ try:
57
+ rxn_choice = next(iter(rxns))
58
+ return str(rxn_choice)
59
+ except:
60
+ return "Error: No reactions found."
61
+
62
+ def _break_equation(self, equation):
63
+ pattern = r'(\d*\.?\d*\s*[A-Za-z]+\d*|\+|\->)'
64
+ pieces = re.findall(pattern, equation)
65
+ equation_pieces = []
66
+ current_piece = ''
67
+ for piece in pieces:
68
+ if piece == '+' or piece == '->':
69
+ equation_pieces.append(current_piece.strip())
70
+ equation_pieces.append(piece)
71
+ current_piece = ''
72
+ else:
73
+ current_piece += piece + ' '
74
+ equation_pieces.append(current_piece.strip())
75
+ return equation_pieces
76
+
77
+ def _convert_equation_pieces(self, equation_pieces):
78
+ if '+' in equation_pieces:
79
+ equation_pieces = [piece if piece != '+' else 'with' for piece in equation_pieces]
80
+ equation_pieces = [piece if piece != '->' else 'to yield' for piece in equation_pieces]
81
+ else:
82
+ equation_pieces = [piece if piece != '->' else 'yields' for piece in equation_pieces]
83
+ return equation_pieces
84
+
85
+ def _split_equation_pieces(self, equation_pieces):
86
+ new_pieces = []
87
+ for piece in equation_pieces:
88
+ if piece in ["with", "to yield", "yields"]:
89
+ new_pieces.append(piece)
90
+ else:
91
+ if re.match(r'^\d*\.\d+|\d+', piece):
92
+ number_match = re.match(r'^\d*\.\d+|\d+', piece)
93
+ number = number_match.group(0)
94
+ rest = piece[len(number):]
95
+ new_pieces.append(number)
96
+ new_pieces.append(rest)
97
+ else:
98
+ new_pieces.append("1")
99
+ new_pieces.append(piece)
100
+ return new_pieces
101
+
102
+ def _modify_mols(self, equation_pieces):
103
+ for i, piece in enumerate(equation_pieces):
104
+ if piece.replace('.', '', 1).isdigit():
105
+ equation_pieces[i] = f"{piece} mols"
106
+ return equation_pieces
107
+
108
+ def _combine_equation_pieces(self, equation_pieces):
109
+ if 'with' in equation_pieces:
110
+ equation_pieces.insert(0, 'mix')
111
+ combined_string = ' '.join(equation_pieces)
112
+ return combined_string
113
+
114
+ def _process_equation(self, equation):
115
+ equation_pieces = self._break_equation(equation)
116
+ converted_pieces = self._convert_equation_pieces(equation_pieces)
117
+ split_pieces = self._split_equation_pieces(converted_pieces)
118
+ modified_pieces = self._modify_mols(split_pieces)
119
+ combined_string = self._combine_equation_pieces(modified_pieces)
120
+ return combined_string
121
+
122
+ def get_reaction(self, input_string):
123
+ input_parts = input_string.split(',', 1)
124
+ if len(input_parts) != 2:
125
+ raise ValueError("Invalid input format. Expected 'precursor' or 'target', followed by a comma, and then the list of formulas separated by a comma.")
126
+
127
+ mode, formulas = input_parts
128
+ mode = mode.lower().strip()
129
+
130
+ if mode == "precursor":
131
+ reaction = self._get_rxn_from_precursor(formulas)
132
+ elif mode == "target":
133
+ reaction = self._get_rxn_from_target(formulas)
134
+ else:
135
+ raise ValueError("Invalid mode. Expected 'precursor' or 'target'.")
136
+ processed_reaction = self._process_equation(reaction)
137
+ return processed_reaction
138
+
139
+ def get_tools(self):
140
+ return [
141
+ Tool(
142
+ name = "Get a synthesis reaction for a material",
143
+ func = self.get_reaction,
144
+ description = (
145
+ "This function is useful for suggesting a synthesis reaction for a material. "
146
+ "Give this tool a string containing either precursor or target, then a comma, followed by the formulas separated by comma as input and returns a synthesis reaction."
147
+ "The mode is used to determine if the input is a precursor or a target material. "
148
+ )
149
+ )]
150
+
utils.py CHANGED
@@ -2,6 +2,7 @@ from langchain.agents import Tool, tool
2
  import requests
3
  from langchain import OpenAI
4
  from langchain import LLMMathChain, SerpAPIWrapper
 
5
  from rdkit import Chem
6
 
7
  @tool
 
2
  import requests
3
  from langchain import OpenAI
4
  from langchain import LLMMathChain, SerpAPIWrapper
5
+ import os
6
  from rdkit import Chem
7
 
8
  @tool