# ChatWithDemonstrationsFlowModule / DemonstrationsAtomicFlow.py
# nbaldwin's picture
# merge with coflows
# 210a49b
import jinja2
from aiflows.base_flows import AtomicFlow
from aiflows.utils import logging
from aiflows.utils import general_helpers
from typing import Dict,Any,Optional,List
from aiflows.prompt_template import JinjaPrompt
from copy import deepcopy
from aiflows.messages import FlowMessage
import os
import hydra
log = logging.get_logger(__name__)
class DemonstrationsAtomicFlow(AtomicFlow):
    """This class implements a Demonstrations Atomic Flow. It is a flow which is usually used to pass demonstrations
    (of user-assistant interactions) to the ChatAtomicFlow.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "DemonstrationsAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "A flow that passes demonstrations to the ChatFlow"
    - `data` (List[Dict[str, Any]]): The data of the demonstrations. Each entry must contain a `query_data` dict and a
      `response_data` dict (and an `id` when `params["ids_to_keep"]` is used). If `data` is None or absent from the
      config, the data is loaded from the file `<params["data_dir"]>/<params["demonstrations_id"]>.jsonl`.
    - `params` (Dict[str, Any]): The parameters specific to the dataset of the demonstrations:
        - `data_dir` (str): The directory where the demonstrations are stored. Used only if the data is not passed
          directly to the flow through `data`. Default: No default value; this field must be set.
        - `demonstrations_id` (str): The id of the demonstrations (name of the data file, without the `.jsonl`
          extension). Used only if the data is not passed directly to the flow through `data`.
          Default: No default value; this field must be set.
        - `demonstrations_k` (int): The number of demonstrations to pass to the ChatFlow (the first k are used).
          If None, all the demonstrations are passed. If larger than the number of available demonstrations,
          it is capped to that number. Default: None
        - `ids_to_keep` (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep, either as a list
          or as a comma-separated string (whitespace around each id is ignored). Only applied when the data is
          loaded from file. If not set, all the demonstrations are kept.
    - `query_prompt_template` (Dict[str, Any]): The prompt template used to generate the query of the demonstrations.
      By default it is of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by
      default, so they need to be defined if one wants to use the query_prompt_template.
    - `response_prompt_template` (Dict[str, Any]): The prompt template used to generate the response of the
      demonstrations. By default it is of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the
      prompt are defined by default, so they need to be defined if one wants to use the response_prompt_template.

    *Input Interface*:

    - The input interface expected by its successor flow (typically ChatAtomicFlow, so the input interface is the
      one expected by ChatAtomicFlow).

    *Output Interface*:

    - Whichever data was passed in the input_message (typically the input expected by ChatAtomicFlow).
    - `demonstrations` (List[Dict[str, Any]]): A list of demonstrations. Each demonstration is a dictionary with the
      following keys (if the input data already contains a `demonstrations` key, the input value is kept):
        - idx (int): The index of the demonstration
        - query (str): The query of the demonstration
        - response (str): The response of the demonstration

    :param params: The parameters specific to the dataset of the demonstrations. It should contain the following keys:
        - 'data_dir' (str): The directory where the demonstrations are stored. This field is used if the data is not
          directly passed to the flow through the 'data' field.
        - 'demonstrations_id' (str): The id of the demonstrations (name of the data file). This field is used if the
          data is not directly passed to the flow through the 'data' field.
        - 'demonstrations_k' (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow.
        - 'ids_to_keep' (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep. If None, all the
          demonstrations are kept.
    :type params: Dict[str, Any]
    :param query_prompt_template: The prompt template used to generate the query of the demonstrations.
    :type query_prompt_template: JinjaPrompt
    :param response_prompt_template: The prompt template used to generate the response of the demonstrations.
    :type response_prompt_template: JinjaPrompt
    :param data: The data of the demonstrations. If None, the data is loaded from the file specified in the params.
    :type data: Optional[List[Dict[str, Any]]]
    """

    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        """Store the prompt templates and parameters; load the demonstrations from file when `data` is None.

        :param params: Dataset parameters (see the class docstring).
        :param query_prompt_template: Template rendered with each demonstration's `query_data`.
        :param response_prompt_template: Template rendered with each demonstration's `response_data`.
        :param data: In-memory demonstrations; if None they are read via `_load_data`.
        :param kwargs: Forwarded to the AtomicFlow constructor (e.g. `flow_config`).
        """
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)
        # typically the query would be what the user (human) asks the assistant (LLM)
        self.query_prompt_template = query_prompt_template
        # typically the response would be what the assistant (LLM) should answer to the user (human)
        self.response_prompt_template = response_prompt_template
        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """This method instantiates the prompt templates of the flow (used when instantiating the flow from a config file).

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """This method instantiates the flow from a config file.

        :param config: The configuration of the flow. `params` is required; `data` is optional and, when absent
            or None, the demonstrations are loaded from file.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: Flow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update({"params": flow_config["params"]})
        # `data` is optional: a missing key means "load from params", same as an explicit None.
        kwargs.update({"data": flow_config.get("data")})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    @staticmethod
    def _render_prompt(prompt_template, sample_data: Dict):
        """Render `prompt_template` using only the entries of `sample_data` named in its `input_variables`.

        :raises KeyError: If `sample_data` lacks one of the template's input variables.
        """
        input_variables = prompt_template.input_variables
        return prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_query_message_content(self, sample_data: Dict):
        """This method returns the query message content of a demonstration given the sample data
        (by rendering the query prompt template).

        :param sample_data: The sample data of the demonstration.
        :type sample_data: Dict[str, Any]
        :return: The query message content of the demonstration.
        :rtype: str
        """
        return self._render_prompt(self.query_prompt_template, sample_data)

    def _get_response_message_content(self, sample_data: Dict):
        """This method returns the response message content of a demonstration given the sample data
        (by rendering the response prompt template).

        :param sample_data: The sample data of the demonstration.
        :type sample_data: Dict[str, Any]
        :return: The response message content of the demonstration.
        :rtype: str
        """
        return self._render_prompt(self.response_prompt_template, sample_data)

    def _get_io_pair(self, idx):
        """This method, given the index of a demonstration, returns a query-response pair from the demonstrations data.

        :param idx: The index of the demonstration.
        :type idx: int
        :return: The query-response pair at idx from the demonstrations data.
        :rtype: Dict[str, Any]
        """
        dp = self.data[idx]

        query_data = dp["query_data"]
        response_data = dp["response_data"]

        query = self._get_query_message_content(query_data)
        response = self._get_response_message_content(response_data)

        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """This method returns the demonstrations that are passed to the destination flow (typically ChatAtomicFlow).

        The first `demonstrations_k` demonstrations are returned (all of them if `demonstrations_k` is None).
        A `demonstrations_k` larger than the number of available demonstrations is capped instead of raising.

        :param input_data: The input data of the flow (currently unused; kept for interface compatibility).
        :type input_data: Dict[str, Any]
        :return: The demonstrations that are passed to the destination flow.
        :rtype: List[Any]
        """
        num_available = len(self.data)
        if self.demonstrations_k is None:
            demonstrations_k = num_available
        else:
            if self.demonstrations_k > num_available:
                # e.g. after `ids_to_keep` filtering fewer demonstrations remain than requested.
                log.warning(
                    "demonstrations_k=%d exceeds the %d available demonstrations; using %d",
                    self.demonstrations_k, num_available, num_available,
                )
            demonstrations_k = min(self.demonstrations_k, num_available)

        return [self._get_io_pair(idx) for idx in range(demonstrations_k)]

    def _load_data(self):
        """This method loads the demonstrations from the file specified in the params. It also filters the
        demonstrations if the ids_to_keep parameter is specified.

        :raises KeyError: If `data_dir`/`demonstrations_id` are missing from params, or a datapoint lacks `id`
            while `ids_to_keep` is set.
        """
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        ids_to_keep = self.params.get("ids_to_keep", False)
        if ids_to_keep:
            if isinstance(ids_to_keep, str):
                # Comma-separated string: ignore whitespace around each id and drop empty entries,
                # so that "a, b" and "a,b," both mean {"a", "b"}.
                ids_to_keep = {item.strip() for item in ids_to_keep.split(",") if item.strip()}
            else:
                ids_to_keep = set(ids_to_keep)
            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), demonstrations_file)

    def run(self,
            input_message: FlowMessage):
        """This method runs the flow. It sends a reply containing the data of the input_message with the
        demonstrations added to it.

        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        """
        input_data = input_message.data

        # Keys from input_data are merged last, so they take precedence over the generated "demonstrations".
        reply = self.package_output_message(
            input_message=input_message,
            response={"demonstrations": self._get_io_pairs(input_data=input_data), **input_data}
        )
        self.send_message(reply)