import os
from copy import deepcopy
from typing import Dict, Any, Optional, List

import hydra
import jinja2

from flows.base_flows import AtomicFlow
from flows.utils import logging
from flows.utils import general_helpers
from flows.prompt_template import JinjaPrompt

log = logging.get_logger(__name__)


class DemonstrationsAtomicFlow(AtomicFlow):
    """This class implements a Demonstrations Atomic Flow. It is a flow which is usually used to pass
    demonstrations (of user-assistant interactions) to the ChatAtomicFlow.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "DemonstrationsAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message
      of the flow. Default: "A flow that passes demonstrations to the ChatFlow"
    - `data` (List[Dict[str, Any]]): The data of the demonstrations. If `data` is None (or absent from the
      config), the data is loaded from the file `<params["data_dir"]>/<params["demonstrations_id"]>.jsonl`.
    - `params` (Dict[str, Any]): The parameters specific to the dataset of the demonstrations. Its default
      parameters are:
        - `data_dir` (str): The directory where the demonstrations are stored. If the data is not directly
          passed to the flow through `data`, the data is loaded from this directory.
          Default: No default value, this field must be set.
        - `demonstrations_id` (str): The id of the demonstrations (name of the data file, without the
          `.jsonl` extension). If the data is not directly passed to the flow through `data`, the data is
          loaded from this file. Default: No default value, this field must be set.
        - `demonstrations_k` (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow. If it exceeds the number of available demonstrations,
          all of them are passed. Default: None
    - `query_prompt_template` (Dict[str, Any]): The prompt template used to generate the query of the
      demonstrations. By default it's of type flows.prompt_template.JinjaPrompt. None of the parameters of the
      prompt are defined by default and therefore need to be defined if one wants to use the
      query_prompt_template. Default parameters are defined in
      flows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `response_prompt_template` (Dict[str, Any]): The prompt template used to generate the response of the
      demonstrations. By default it's of type flows.prompt_template.JinjaPrompt. None of the parameters of the
      prompt are defined by default and therefore need to be defined if one wants to use the
      response_prompt_template. Default parameters are defined in
      flows.prompt_template.jinja2_prompts.JinjaPrompt.

    *Input Interface*:

    - The input interface expected by its successor flow (typically ChatAtomicFlow, so the input interface
      is the one expected by ChatAtomicFlow).

    *Output Interface*:

    - All the keys of the input data, passed through unchanged (i.e. the input interface expected by its
      successor flow, typically ChatAtomicFlow).
    - `demonstrations` (List[Dict[str, Any]]): A list of demonstrations. Each demonstration is a dictionary
      with the following keys:
        - idx (int): The index of the demonstration
        - query (str): The query of the demonstration
        - response (str): The response of the demonstration

    :param params: The parameters specific to the dataset of the demonstrations. It should contain the
        following keys:
        - 'data_dir' (str): The directory where the demonstrations are stored. This field is used if the
          data is not directly passed to the flow through the 'data' field.
        - 'demonstrations_id' (str): The id of the demonstrations (name of the data file). This field is used
          if the data is not directly passed to the flow through the 'data' field.
        - 'demonstrations_k' (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow.
        - 'ids_to_keep' (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep, either as a
          list or as a comma-separated string. If None, all the demonstrations are kept.
    :type params: Dict[str, Any]
    :param query_prompt_template: The prompt template used to generate the query of the demonstrations.
    :type query_prompt_template: JinjaPrompt
    :param response_prompt_template: The prompt template used to generate the response of the demonstrations.
    :type response_prompt_template: JinjaPrompt
    :param data: The data of the demonstrations. If None, the data is loaded from the file specified in the
        params.
    :type data: Optional[List[Dict[str, Any]]]
    """

    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)
        # typically the query would be what the user (human) asks the assistant (LLM)
        self.query_prompt_template = query_prompt_template
        # typically the response would be what the assistant (LLM) should answer to the user (human)
        self.response_prompt_template = response_prompt_template
        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """This method instantiates the prompt templates of the flow (used when instantiating the flow from a
        config file).

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}
        kwargs["query_prompt_template"] = hydra.utils.instantiate(
            config["query_prompt_template"], _convert_="partial"
        )
        kwargs["response_prompt_template"] = hydra.utils.instantiate(
            config["response_prompt_template"], _convert_="partial"
        )
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """This method instantiates the flow from a config file.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: Flow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update({"params": flow_config["params"]})
        # `data` is optional: when it is missing (or None) the constructor loads the data from file.
        kwargs.update({"data": flow_config.get("data", None)})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _get_query_message_content(self, sample_data: Dict):
        """This method returns the query message content of a demonstration given the sample data (by
        rendering the query prompt template).

        :param sample_data: The sample data of the demonstration.
        :type sample_data: Dict[str, Any]
        :return: The query message content of the demonstration.
        :rtype: str
        """
        input_variables = self.query_prompt_template.input_variables
        return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_response_message_content(self, sample_data: Dict):
        """This method returns the response message content of a demonstration given the sample data (by
        rendering the response prompt template).

        :param sample_data: The sample data of the demonstration.
        :type sample_data: Dict[str, Any]
        :return: The response message content of the demonstration.
        :rtype: str
        """
        input_variables = self.response_prompt_template.input_variables
        return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_io_pair(self, idx):
        """This method, given the index of a demonstration, returns a query-response pair from the
        demonstrations data.

        :param idx: The index of the demonstration.
        :type idx: int
        :return: The query-response pair at idx from the demonstrations data.
        :rtype: Dict[str, Any]
        """
        dp = self.data[idx]

        query_data = dp["query_data"]
        response_data = dp["response_data"]

        query = self._get_query_message_content(query_data)
        response = self._get_response_message_content(response_data)

        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """This method returns the demonstrations that are passed to the destination flow (typically
        ChatAtomicFlow).

        :param input_data: The input data of the flow (currently unused).
        :type input_data: Dict[str, Any]
        :return: The first `demonstrations_k` demonstrations (or all of them if `demonstrations_k` is None or
            larger than the number of available demonstrations).
        :rtype: List[Any]
        """
        num_demonstrations = len(self.data)
        if self.demonstrations_k is not None:
            # Clamp so that requesting more demonstrations than are available does not raise IndexError.
            num_demonstrations = min(self.demonstrations_k, num_demonstrations)
        return [self._get_io_pair(idx) for idx in range(num_demonstrations)]

    def _load_data(self):
        """This method loads the demonstrations from the file `<data_dir>/<demonstrations_id>.jsonl`
        specified in the params. It also filters the demonstrations if the ids_to_keep parameter is
        specified (either as a list of ids or as a comma-separated string)."""
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        ids_to_keep_param = self.params.get("ids_to_keep", False)
        if ids_to_keep_param:
            if isinstance(ids_to_keep_param, str):
                # Strip whitespace so "a, b" matches ids "a" and "b"; drop empty entries (e.g. trailing comma).
                ids_to_keep = {id_.strip() for id_ in ids_to_keep_param.split(",") if id_.strip()}
            else:
                ids_to_keep = set(ids_to_keep_param)

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """This method runs the flow. It returns the input data of the flow with the demonstrations added
        to it.

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        :return: The input data of the flow with the demonstrations added to it.
        :rtype: Dict[str, Any]
        """
        return {**input_data, **{"demonstrations": self._get_io_pairs(input_data=input_data)}}