from aiflows.base_flows import CompositeFlow
from aiflows.utils import logging
from aiflows.interfaces import KeyInterface
from aiflows.messages import FlowMessage
from typing import Dict, Any, Optional

log = logging.get_logger(f"aiflows.{__name__}")


class FunSearch(CompositeFlow):
    """
    This class implements FunSearch. This code is an implementation of Funsearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch).
    It's a Flow in charge of starting, stopping and managing (passing around messages) the FunSearch process. It passes messages around to the following subflows:

    - ProgramDBFlow: which is in charge of storing and retrieving programs.
    - SamplerFlow: which is in charge of sampling programs.
    - EvaluatorFlow: which is in charge of evaluating programs.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "FunSearchFlow".
    - `description` (str): The description of the flow. Default: "A flow implementing FunSearch"
    - `subflows_config` (Dict[str,Any]): A dictionary of subflows configurations. Default:
        - `ProgramDBFlow`: By default, it uses the `ProgramDBFlow` class and uses its default parameters.
        - `SamplerFlow`: By default, it uses the `SamplerFlow` class and uses its default parameters.
        - `EvaluatorFlow`: By default, it uses the `EvaluatorFlow` class and uses its default parameters.

    **Input Interface**:

    - `from` (str): The flow the message is coming from. It can be one of the following: "FunSearch", "SamplerFlow", "EvaluatorFlow", "ProgramDBFlow". Defaults to "FunSearch" when omitted.
    - `operation` (str): The operation to perform. It can be one of the following: "start", "stop", "get_prompt", "get_best_programs_per_island", "register_program".
    - `content` (Dict[str,Any]): The content associated with an operation. Here is the expected content for each operation:
        - "start":
            - `num_samplers` (int): The number of samplers to start up. Note that it's still restricted by the number of workers available. Default: 1.
        - "stop":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "get_prompt":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "get_best_programs_per_island":
            - No content. Pass either an empty dictionary or None. Works also with no content.
        - "register_program":
            - Messages whose `from` is "SamplerFlow" are handled as "register_program" requests. They must carry the program in `api_output`.
              The program is evaluated by the EvaluatorFlow and then stored in the ProgramDBFlow.

    **Output Interface**:

    - `retrieved` (Dict[str,Any]): The retrieved data.

    **Citation**:

    @Article{FunSearch2023,
      author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
      journal = {Nature},
      title   = {Mathematical discoveries from program search with large language models},
      year    = {2023},
      doi     = {10.1038/s41586-023-06924-6}
    }
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Routing table: for each operation, maps the flow recorded as the message's
        # latest sender ("from" in the message state) to the next step to execute.
        self.next_state_per_action = {
            "get_prompt": {
                "FunSearch": "ProgramDBFlow",
                "ProgramDBFlow": "SamplerFlow",
            },
            "get_best_programs_per_island": {
                "FunSearch": "ProgramDBFlow",
                "ProgramDBFlow": "GenerateReply",
            },
            "register_program": {
                "SamplerFlow": "EvaluatorFlow",
                "EvaluatorFlow": "ProgramDBFlow",
            },
            "start": {"FunSearch": "FunSearch"},
            "stop": {"FunSearch": "FunSearch"},
        }

        # Key interface used to build the self-addressed "get_prompt" request data.
        self.make_request_for_prompt_data = KeyInterface(
            keys_to_set={"operation": "get_prompt", "content": {}, "from": "FunSearch"},
            keys_to_select=["operation", "content", "from"],
        )

    def make_request_for_prompt(self):
        """
        This method makes a request for a prompt. It sends a message to itself with the
        operation "get_prompt", which will trigger the flow to call the `ProgramDBFlow` to get a prompt.
        """
        # Prepare data to make a request for a prompt
        data = self.make_request_for_prompt_data({})

        # Package the message to make the request for a prompt
        msg = self.package_input_message(
            data=data,
            dst_flow="FunSearch",
        )

        # Send the message to itself to start the process of getting a prompt
        self.send_message(msg)

    def request_samplers(self, input_message: FlowMessage):
        """
        This method requests samplers. For each requested sampler it sends a message to itself
        with the operation "get_prompt" (see `make_request_for_prompt`).
        The state of `input_message` is consumed (removed from the flow state).

        :param input_message: The input message that triggered the request for samplers.
        :type input_message: FlowMessage
        """
        # Get (and drop) the state associated with the message: the "start" request is fully handled here
        message_state = self.pop_message_from_state(input_message.input_message_id)

        # Get the number of samplers to request
        num_samplers = message_state["content"].get("num_samplers", 1)

        for _ in range(num_samplers):
            self.make_request_for_prompt()

    def get_next_state(self, input_message: FlowMessage) -> Optional[str]:
        """
        This method determines the next state of the flow based on the input message.
        It looks up the operation and the latest sender recorded in the message state in
        `next_state_per_action`.

        :param input_message: The input message that triggered the request for the next state.
        :type input_message: FlowMessage
        :return: The next state of the flow, or None if no transition is defined for the
            (operation, sender) pair recorded in the message state.
        :rtype: Optional[str]
        """
        # Get the state associated with the message
        message_state = self.get_message_from_state(input_message.input_message_id)

        message_from = message_state.get("from")
        operation = message_state.get("operation")

        # Unknown operations/senders yield None so the caller can log and discard
        # the message instead of crashing with a KeyError.
        return self.next_state_per_action.get(operation, {}).get(message_from)

    def set_up_flow_state(self):
        """
        This method sets up the state of the flow. It's called at the beginning of the flow.
        """
        super().set_up_flow_state()
        # Dictionary containing the state of each message currently being handled by FunSearch.
        # Each message has its own state in the flow state.
        # Once a message is done being handled, it's removed from the state.
        self.flow_state["msg_requests"] = {}

        # Flag to keep track of whether the first sample has been saved to the db
        self.flow_state["first_sample_saved_to_db"] = False

        # Flag to keep track of whether FunSearch is running
        self.flow_state["funsearch_running"] = False

    def save_message_to_state(self, msg_id: str, message: FlowMessage):
        """
        This method saves a message to the state of the flow. It's used to keep track of state
        on a per message basis (i.e., state of the flow depending on the message received and id).

        :param msg_id: The id of the message to save.
        :type msg_id: str
        :param message: The message to save.
        :type message: FlowMessage
        """
        self.flow_state["msg_requests"][msg_id] = {"og_message": message}

    def rename_key_message_in_state(self, old_key: str, new_key: str):
        """
        This method renames a key in the "msg_requests" dictionary of the flow state
        (i.e., it re-keys a message's state under a new message id).

        :param old_key: The old key to rename.
        :type old_key: str
        :param new_key: The new key to rename to.
        :type new_key: str
        """
        self.flow_state["msg_requests"][new_key] = self.flow_state["msg_requests"].pop(old_key)

    def message_in_state(self, msg_id: str) -> bool:
        """
        This method checks if a message is in the state of the flow (in the "msg_requests" dictionary).

        :param msg_id: The id of the message to check.
        :type msg_id: str
        :return: True if the message is in the state, otherwise False.
        :rtype: bool
        """
        return msg_id in self.flow_state["msg_requests"]

    def get_message_from_state(self, msg_id: str) -> Dict[str, Any]:
        """
        This method returns the state associated with a message id (from the "msg_requests" dictionary).

        :param msg_id: The id of the message to get the state from.
        :type msg_id: str
        :return: The state associated with the message id.
        :rtype: Dict[str,Any]
        """
        return self.flow_state["msg_requests"][msg_id]

    def pop_message_from_state(self, msg_id: str) -> Dict[str, Any]:
        """
        This method pops a message from the state of the flow (from the "msg_requests" dictionary):
        it returns the state associated with the message and removes it from the flow state.

        :param msg_id: The id of the message to pop from the state.
        :type msg_id: str
        :return: The state associated with the message id.
        :rtype: Dict[str,Any]
        """
        return self.flow_state["msg_requests"].pop(msg_id)

    def merge_message_request_state(self, id: str, new_states: Dict[str, Any]):
        """
        This method merges new entries into a message's state (in the "msg_requests" dictionary).
        Keys in `new_states` overwrite existing keys (shallow merge).

        :param id: The id of the message to merge new states into.
        :type id: str
        :param new_states: The new states to merge into the message state.
        :type new_states: Dict[str,Any]
        """
        self.flow_state["msg_requests"][id] = {**self.flow_state["msg_requests"][id], **new_states}

    def register_data_to_state(self, input_message: FlowMessage):
        """
        This method registers the input message data to the flow state.
        It's called by `run` every time a new input message is received.

        :param input_message: The input message
        :type input_message: FlowMessage
        """
        # Determine who the message is from (should be FunSearch, SamplerFlow, EvaluatorFlow, or ProgramDBFlow)
        msg_from = input_message.data.get("from", "FunSearch")

        # A message already in the state is a reply belonging to a request being handled;
        # otherwise it is a new request.
        msg_id = input_message.input_message_id
        if not self.message_in_state(msg_id):
            self.save_message_to_state(msg_id, input_message)

        # Get the state associated with the message
        message_state = self.get_message_from_state(msg_id)

        # Determine what to do based on who the message is from
        if msg_from == "FunSearch":
            # Calls from FunSearch carry an operation and (optionally) content
            to_add_to_state = {
                "content": input_message.data.get("content", {}),
                "operation": input_message.data["operation"],
            }
            self.merge_message_request_state(msg_id, to_add_to_state)

        elif msg_from == "SamplerFlow":
            # Calls from SamplerFlow carry api_output (the sampled program) to register
            to_add_to_state = {
                "content": {
                    **message_state.get("content", {}),
                    "artifact": input_message.data["api_output"],
                },
                "operation": "register_program",
            }
            self.merge_message_request_state(msg_id, to_add_to_state)

        elif msg_from == "EvaluatorFlow":
            # Calls from EvaluatorFlow carry scores_per_test
            to_add_to_state = {
                "content": {
                    **message_state.get("content", {}),
                    "scores_per_test": input_message.data["scores_per_test"],
                }
            }
            self.merge_message_request_state(msg_id, to_add_to_state)

        elif msg_from == "ProgramDBFlow":
            # Calls from ProgramDBFlow carry retrieved
            to_add_to_state = {
                "retrieved": input_message.data["retrieved"],
            }

            # For "get_prompt" replies, also remember which island the prompt came from
            if message_state.get("operation") == "get_prompt":
                island_id = input_message.data["retrieved"]["island_id"]
                to_add_to_state["content"] = {
                    **message_state.get("content", {}),
                    "island_id": island_id,
                }

            self.merge_message_request_state(msg_id, to_add_to_state)

        # Record the latest sender; get_next_state routes on it
        self.merge_message_request_state(msg_id, {"from": msg_from})

    def call_program_db(self, input_message):
        """
        This method calls the ProgramDBFlow. It sends a message to the ProgramDBFlow with the
        operation and content stored in the state of the input message.

        :param input_message: The input message to send to the ProgramDBFlow.
        :type input_message: FlowMessage
        """
        # Fetch the state associated with the message
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)

        # Get the operation and content from the state to send to ProgramDBFlow
        operation = message_state["operation"]
        content = message_state.get("content", {})

        data = {
            "operation": operation,
            "content": content,
        }

        # Package the message to send to ProgramDBFlow
        msg = self.package_input_message(
            data=data,
            dst_flow="ProgramDBFlow",
        )

        if operation == "register_program":
            # The initial request is fully handled once the program is sent to the db:
            # drop its state and fire-and-forget (no reply expected).
            self.pop_message_from_state(msg_id)
            self.flow_state["first_sample_saved_to_db"] = True
            self.subflows["ProgramDBFlow"].send_message(msg)

        elif operation in ("get_prompt", "get_best_programs_per_island"):
            if not self.flow_state["first_sample_saved_to_db"]:
                # Nothing in the db yet: resend the original message to ourselves to try again.
                # Drop its current state first; register_data_to_state rebuilds it from the
                # message data when the message comes back (avoids leaking one orphaned
                # entry per retry).
                self.pop_message_from_state(msg_id)
                self.send_message(input_message)
            else:
                # Re-key the state under the outgoing message id so the reply (whose
                # input_message_id is expected to be that id) finds it.
                self.rename_key_message_in_state(msg_id, msg.message_id)
                self.subflows["ProgramDBFlow"].get_reply(msg)

        else:
            log.error("No operation found, input_message received: \n" + str(input_message))

    def call_evaluator(self, input_message):
        """
        This method calls the EvaluatorFlow. It sends the artifact stored in the state of the
        input message to the EvaluatorFlow and expects a reply.

        :param input_message: The input message to send to the EvaluatorFlow.
        :type input_message: FlowMessage
        """
        # Fetch the state associated with the message
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)

        # Data to send to EvaluatorFlow (artifact generated by the sampler, to be evaluated)
        data = {
            "artifact": message_state["content"]["artifact"],
        }

        msg = self.package_input_message(
            data=data,
            dst_flow="EvaluatorFlow",
        )

        # Re-key the state under the new message id so the reply can be matched to it
        self.rename_key_message_in_state(msg_id, msg.message_id)

        # Send the message to EvaluatorFlow and expect a reply in FunSearch's input queue
        self.subflows["EvaluatorFlow"].get_reply(msg)

    def call_sampler(self, input_message):
        """
        This method calls the SamplerFlow. It sends the data retrieved from the ProgramDBFlow
        (stored in the state of the input message) to the SamplerFlow and expects a reply.
        If FunSearch is running, it also makes a new request for a prompt.

        :param input_message: The input message to send to the SamplerFlow.
        :type input_message: FlowMessage
        """
        # Fetch the state associated with the message
        msg_id = input_message.input_message_id
        message_state = self.get_message_from_state(msg_id)

        # Data to send to SamplerFlow (prompt to generate a program)
        data = {
            **message_state["retrieved"],
        }

        msg = self.package_input_message(
            data=data,
            dst_flow="SamplerFlow",
        )

        # Re-key the state under the new message id so the reply can be matched to it
        self.rename_key_message_in_state(msg_id, msg.message_id)

        # Send the message to SamplerFlow and expect a reply in FunSearch's input queue
        self.subflows["SamplerFlow"].get_reply(msg)

        # If FunSearch is running, make a new request for a prompt (to keep the process going)
        if self.flow_state["funsearch_running"]:
            self.make_request_for_prompt()

    def generate_reply(self, input_message: FlowMessage):
        """
        This method generates the reply to the original request (e.g. get_best_programs_per_island).
        It packages the output message from the retrieved data and sends it.

        :param input_message: The input message to generate a reply to.
        :type input_message: FlowMessage
        """
        # Fetch (and drop) the state associated with the message: this request ends here
        msg_id = input_message.input_message_id
        message_state = self.pop_message_from_state(msg_id)

        # Prepare the response to send to the user (due to a call to get_best_programs_per_island)
        response = {
            "retrieved": message_state["retrieved"],
        }

        reply = self.package_output_message(
            message_state["og_message"],
            response,
        )

        self.send_message(reply)

    def run(self, input_message: FlowMessage):
        """
        This method runs the flow. It's the main method of the flow; it's called for every
        message the flow receives. It records the message in the flow state and dispatches
        it to the next step.

        :param input_message: The input message of the flow
        :type input_message: FlowMessage
        """
        self.register_data_to_state(input_message)

        next_state = self.get_next_state(input_message)

        if next_state == "ProgramDBFlow":
            self.call_program_db(input_message)
        elif next_state == "EvaluatorFlow":
            self.call_evaluator(input_message)
        elif next_state == "SamplerFlow":
            self.call_sampler(input_message)
        elif next_state == "GenerateReply":
            self.generate_reply(input_message)
        elif next_state == "FunSearch":
            operation = input_message.data["operation"]
            if operation == "start":
                # Mark FunSearch as running and request the initial prompts
                self.flow_state["funsearch_running"] = True
                self.request_samplers(input_message)
            elif operation == "stop":
                # Stop generating new samples; the "stop" request itself needs no further handling
                self.flow_state["funsearch_running"] = False
                self.pop_message_from_state(input_message.input_message_id)
        else:
            log.error("No next state found, input_message received: \n" + str(input_message))
            # The message cannot be routed: discard its state so it doesn't linger
            self.pop_message_from_state(input_message.input_message_id)