|
import jinja2 |
|
from aiflows.base_flows import AtomicFlow |
|
from aiflows.utils import logging |
|
from aiflows.utils import general_helpers |
|
from typing import Dict,Any,Optional,List |
|
from aiflows.prompt_template import JinjaPrompt |
|
from copy import deepcopy |
|
from aiflows.messages import FlowMessage |
|
import os |
|
import hydra |
|
# Module-level logger, obtained from aiflows' logging utilities (aiflows.utils.logging).
log = logging.get_logger(__name__)
|
|
|
class DemonstrationsAtomicFlow(AtomicFlow):
    """This class implements a Demonstrations Atomic Flow. It is a flow which is usually used to pass
    demonstrations (of user-assistant interactions) to the ChatAtomicFlow.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "DemonstrationsAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "A flow that passes demonstrations to the ChatFlow"
    - `data` (List[Dict[str, Any]]): The data of the demonstrations.
      If data is None (or absent from the config), the data is loaded from the file
      `<params["data_dir"]>/<params["demonstrations_id"]>.jsonl`.
    - `params` (Dict[str, Any]): The parameters specific to the dataset of the demonstrations. Its parameters are:
        - `data_dir` (str): The directory where the demonstrations are stored. If the data is not directly passed
          to the flow through `data`, then the data is loaded from this directory. No default value; this field
          must be set when `data` is not provided.
        - `demonstrations_id` (str): The id of the demonstrations (name of the data file, without the `.jsonl`
          extension). If the data is not directly passed to the flow through `data`, then the data is loaded from
          this file. No default value; this field must be set when `data` is not provided.
        - `demonstrations_k` (int): The number of demonstrations to pass to the ChatFlow.
          If None, all the demonstrations are passed to the ChatFlow. If it exceeds the number of available
          demonstrations, all available demonstrations are passed and a warning is logged. Default: None
        - `ids_to_keep` (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep, either as a list
          or as a comma-separated string. Only applied when the data is loaded from file. If None, all the
          demonstrations are kept.
    - `query_prompt_template` (Dict[str, Any]): The prompt template used to generate the query of the demonstrations.
      By default it is of type flows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by
      default and therefore they need to be defined if one wants to use the query_prompt_template. Default parameters
      are defined in flows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `response_prompt_template` (Dict[str, Any]): The prompt template used to generate the response of the
      demonstrations. By default it is of type flows.prompt_template.JinjaPrompt. None of the parameters of the prompt
      are defined by default and therefore they need to be defined if one wants to use the response_prompt_template.
      Default parameters are defined in flows.prompt_template.jinja2_prompts.JinjaPrompt.

    *Input Interface*:

    - The input interface expected by its successor flow (typically ChatAtomicFlow, so the input interface is the
      one expected by ChatAtomicFlow).

    *Output Interface*:

    - Whatever data was passed in the input_message (typically the input interface expected by ChatAtomicFlow).
    - `demonstrations` (List[Dict[str, Any]]): A list of demonstrations. Each demonstration is a dictionary with the
      following keys:
        - idx (int): The index of the demonstration
        - query (str): The query of the demonstration
        - response (str): The response of the demonstration

    :param params: The parameters specific to the dataset of the demonstrations. It should contain the following keys:
        - 'data_dir' (str): The directory where the demonstrations are stored. This field is used if the data is not
          directly passed to the flow through the 'data' field.
        - 'demonstrations_id' (str): The id of the demonstrations (name of the data file). This field is used if the
          data is not directly passed to the flow through the 'data' field.
        - 'demonstrations_k' (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow.
        - 'ids_to_keep' (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep. If None, all the
          demonstrations are kept.
    :type params: Dict[str, Any]
    :param query_prompt_template: The prompt template used to generate the query of the demonstrations.
    :type query_prompt_template: JinjaPrompt
    :param response_prompt_template: The prompt template used to generate the response of the demonstrations.
    :type response_prompt_template: JinjaPrompt
    :param data: The data of the demonstrations. If None, the data is loaded from the file specified in the params.
    :type data: Optional[List[Dict[str, Any]]]
    """

    demonstrations_k: Optional[int] = None

    query_prompt_template: JinjaPrompt

    response_prompt_template: JinjaPrompt

    params: Dict

    def __init__(
        self,
        params: Dict[str, Any],
        query_prompt_template: JinjaPrompt,
        response_prompt_template: JinjaPrompt,
        data: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ):
        """Store the dataset parameters and prompt templates, loading the demonstrations from file if needed.

        :param params: Dataset parameters (see the class docstring for the recognised keys).
        :param query_prompt_template: Template rendered with each demonstration's "query_data".
        :param response_prompt_template: Template rendered with each demonstration's "response_data".
        :param data: The demonstrations; when None they are read via ``_load_data``.
        :param kwargs: Forwarded unchanged to ``AtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)
        self.query_prompt_template = query_prompt_template
        self.response_prompt_template = response_prompt_template

        # Only touch the filesystem when no demonstrations were supplied directly.
        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """This method instantiates the prompt templates of the flow (used when instantiating the flow from a config file).

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """This method instantiates the flow from a config file.

        :param config: The configuration of the flow. Must contain "params", "query_prompt_template" and
            "response_prompt_template"; "data" is optional.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: Flow
        """
        # Work on a copy so the caller's config is not mutated by instantiation.
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs["params"] = flow_config["params"]
        # "data" is optional: when it is missing (or None) the demonstrations are
        # loaded from params["data_dir"] by __init__ instead of raising KeyError.
        kwargs["data"] = flow_config.get("data")

        return cls(**kwargs)

    def _get_query_message_content(self, sample_data: Dict) -> str:
        """This method returns the query message content of a demonstration given the sample data
        (by rendering the query prompt template).

        :param sample_data: The sample data of the demonstration; must contain every input variable of the template.
        :type sample_data: Dict[str, Any]
        :return: The query message content of the demonstration.
        :rtype: str
        """
        input_variables = self.query_prompt_template.input_variables
        return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_response_message_content(self, sample_data: Dict) -> str:
        """This method returns the response message content of a demonstration given the sample data
        (by rendering the response prompt template).

        :param sample_data: The sample data of the demonstration; must contain every input variable of the template.
        :type sample_data: Dict[str, Any]
        :return: The response message content of the demonstration.
        :rtype: str
        """
        input_variables = self.response_prompt_template.input_variables
        return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_io_pair(self, idx: int) -> Dict[str, Any]:
        """This method, given the index of a demonstration, returns a query-response pair from the demonstrations data.

        :param idx: The index of the demonstration.
        :type idx: int
        :return: The query-response pair at idx from the demonstrations data.
        :rtype: Dict[str, Any]
        """
        dp = self.data[idx]

        query = self._get_query_message_content(dp["query_data"])
        response = self._get_response_message_content(dp["response_data"])

        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """This method returns the demonstrations that are passed to the destination flow (typically ChatAtomicFlow).

        At most ``demonstrations_k`` demonstrations are returned (all of them when it is None). If more are
        requested than are available, all available demonstrations are returned and a warning is logged.

        :param input_data: The input data of the flow (currently unused; kept for interface compatibility).
        :type input_data: Dict[str, Any]
        :return: The demonstrations that are passed to the destination flow.
        :rtype: List[Dict[str, Any]]
        """
        n_available = len(self.data)
        if self.demonstrations_k is None:
            n_demonstrations = n_available
        else:
            # Clamp so that asking for more demonstrations than exist does not raise IndexError.
            n_demonstrations = min(self.demonstrations_k, n_available)
            if self.demonstrations_k > n_available:
                log.warning(
                    "Requested %d demonstrations but only %d are available; using %d.",
                    self.demonstrations_k,
                    n_available,
                    n_available,
                )

        return [self._get_io_pair(idx) for idx in range(n_demonstrations)]

    def _load_data(self):
        """This method loads the demonstrations from the file specified in the params.

        The file read is ``<data_dir>/<demonstrations_id>.jsonl``. If ``ids_to_keep`` is set (a list of ids or a
        comma-separated string), only the demonstrations whose "id" is in that set are kept.
        """
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        raw_ids = self.params.get("ids_to_keep")
        if raw_ids:
            if isinstance(raw_ids, str):
                # Accept "id1,id2" as well as "id1, id2": strip whitespace and skip empty entries.
                ids_to_keep = {id_.strip() for id_ in raw_ids.split(",") if id_.strip()}
            else:
                ids_to_keep = set(raw_ids)

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), demonstrations_file)

    def run(self, input_message: FlowMessage):
        """This method runs the flow. It sends a reply containing the data of the input_message with the
        demonstrations added to it.

        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        """
        input_data = input_message.data
        # Keys already present in input_data take precedence over the generated
        # "demonstrations" entry, because later unpacking wins in a dict display.
        response = {"demonstrations": self._get_io_pairs(input_data=input_data), **input_data}
        reply = self.package_output_message(input_message=input_message, response=response)
        self.send_message(reply)