File size: 10,889 Bytes
82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 798fa73 82b9374 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import jinja2
from flows.base_flows import AtomicFlow
from flows.utils import logging
from flows.utils import general_helpers
from typing import Dict,Any,Optional,List
from flows.prompt_template import JinjaPrompt
from copy import deepcopy
import os
import hydra
# Module-level logger (from flows.utils.logging), used by DemonstrationsAtomicFlow to report data loading.
log = logging.get_logger(__name__)
class DemonstrationsAtomicFlow(AtomicFlow):
    """ This class implements a Demonstrations Atomic Flow. It is a flow which is usually used to pass demonstrations
    (of user-assistant interactions) to the ChatAtomicFlow.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "DemonstrationsAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "A flow that passes demonstrations to the ChatFlow"
    - `data` (List[Dict[str, Any]]): The data of the demonstrations. Each entry must contain a `query_data` dictionary
      and a `response_data` dictionary. If data is None (or absent from the config), the data is loaded from the file
      specified by params["data_dir"] and params["demonstrations_id"]. Default: None
    - `params` (Dict[str, Any]): The parameters specific to the dataset of the demonstrations. Its default parameters are:
        - `data_dir` (str): The directory where the demonstrations are stored. If the data is not directly passed to
          the flow through `data`, then the data is loaded from this directory. Default: No default value, this field must be set.
        - `demonstrations_id` (str): The id of the demonstrations (name of the data file, without the `.jsonl` extension).
          If the data is not directly passed to the flow through `data`, then the data is loaded from this file.
          Default: No default value, this field must be set.
        - `demonstrations_k` (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow. If it exceeds the number of available demonstrations, all of
          them are passed. Default: None
        - `ids_to_keep` (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep (either a list or a
          comma-separated string). Only used when the data is loaded from file. If None, all the demonstrations are kept.
    - `query_prompt_template` (Dict[str, Any]): The prompt template used to generate the query of the demonstrations.
      By default it's of type flows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by
      default and therefore need to be defined if one wants to use the query_prompt_template. Default parameters are
      defined in flows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `response_prompt_template` (Dict[str, Any]): The prompt template used to generate the response of the
      demonstrations. By default it's of type flows.prompt_template.JinjaPrompt. None of the parameters of the prompt
      are defined by default and therefore need to be defined if one wants to use the response_prompt_template.
      Default parameters are defined in flows.prompt_template.jinja2_prompts.JinjaPrompt.

    *Input Interface*:

    - The input interface expected by its successor flow (typically ChatAtomicFlow, so the input interface is the one
      expected by ChatAtomicFlow)

    *Output Interface*:

    - Every key of the input data, passed through unchanged (i.e. the input interface expected by its successor flow)
    - `demonstrations` (List[Dict[str, Any]]): A list of demonstrations. Each demonstration is a dictionary with the
      following keys:
        - idx (int): The index of the demonstration
        - query (str): The query of the demonstration
        - response (str): The response of the demonstration

    :param params: The parameters specific to the dataset of the demonstrations. It should contain the following keys:
        - 'data_dir' (str): The directory where the demonstrations are stored. This field is used if the data is not
          directly passed to the flow through the 'data' field.
        - 'demonstrations_id' (str): The id of the demonstrations (name of the data file). This field is used if the
          data is not directly passed to the flow through the 'data' field.
        - 'demonstrations_k' (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow.
        - 'ids_to_keep' (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep. If None, all the
          demonstrations are kept.
    :type params: Dict[str, Any]
    :param query_prompt_template: The prompt template used to generate the query of the demonstrations.
    :type query_prompt_template: JinjaPrompt
    :param response_prompt_template: The prompt template used to generate the response of the demonstrations.
    :type response_prompt_template: JinjaPrompt
    :param data: The data of the demonstrations. If None, the data is loaded from the file specified in the params.
    :type data: Optional[List[Dict[str, Any]]]
    """

    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)
        # typically the query would be what the user (human) asks the assistant (LLM)
        self.query_prompt_template = query_prompt_template
        # typically the response would be what the assistant (LLM) should answer to the user (human)
        self.response_prompt_template = response_prompt_template
        # When no demonstrations were handed in directly, read them from params["data_dir"].
        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """ This method instantiates the prompt templates of the flow (used when instantiating the flow from a config file)

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a config file.

        :param config: The configuration of the flow. The "data" key is optional; when it is missing (or None)
            the demonstrations are loaded from the file described by config["params"].
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: Flow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update({"params": flow_config["params"]})
        # `data` is optional: a missing key must fall back to loading from file instead of raising KeyError.
        kwargs.update({"data": flow_config.get("data")})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _get_query_message_content(self, sample_data: Dict):
        """ This method returns the query message content of a demonstration given the sample data
        (by rendering the query prompt template).

        :param sample_data: The sample data of the demonstration; must contain every input variable of the template.
        :type sample_data: Dict[str, Any]
        :return: The query message content of the demonstration.
        :rtype: str
        """
        input_variables = self.query_prompt_template.input_variables
        return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_response_message_content(self, sample_data: Dict):
        """ This method returns the response message content of a demonstration given the sample data
        (by rendering the response prompt template).

        :param sample_data: The sample data of the demonstration; must contain every input variable of the template.
        :type sample_data: Dict[str, Any]
        :return: The response message content of the demonstration.
        :rtype: str
        """
        input_variables = self.response_prompt_template.input_variables
        return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_io_pair(self, idx):
        """ This method, given the index of a demonstration, returns a query-response pair from the demonstrations data.

        :param idx: The index of the demonstration.
        :type idx: int
        :return: The query-response pair at idx from the demonstrations data.
        :rtype: Dict[str, Any]
        """
        dp = self.data[idx]

        query_data = dp["query_data"]
        response_data = dp["response_data"]

        query = self._get_query_message_content(query_data)
        response = self._get_response_message_content(response_data)

        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """ This method returns the demonstrations that are passed to the destination flow (typically ChatAtomicFlow).

        The first `demonstrations_k` demonstrations are returned (all of them when `demonstrations_k` is None or
        larger than the number of available demonstrations).

        :param input_data: The input data of the flow (currently unused).
        :type input_data: Dict[str, Any]
        :return: The demonstrations that are passed to the destination flow.
        :rtype: List[Any]
        """
        num_available = len(self.data)
        if self.demonstrations_k is None:
            demonstrations_k = num_available
        else:
            # Clamp so that requesting more demonstrations than exist does not raise IndexError.
            demonstrations_k = min(self.demonstrations_k, num_available)
        return [self._get_io_pair(idx) for idx in range(demonstrations_k)]

    def _load_data(self):
        """ This method loads the demonstrations from the file specified in the params
        (`<data_dir>/<demonstrations_id>.jsonl`). It also filters the demonstrations if the ids_to_keep parameter
        is specified."""
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        ids_to_keep = self.params.get("ids_to_keep", False)
        if ids_to_keep:
            if isinstance(ids_to_keep, str):
                # Comma-separated form, e.g. "a, b,c": strip surrounding whitespace and ignore blank tokens.
                keep = {token.strip() for token in ids_to_keep.split(",") if token.strip()}
                # Ids parsed from a string are always strings, so compare against the string form of each id
                # (otherwise datapoints with integer ids could never match).
                self.data = [d for d in self.data if str(d["id"]) in keep]
            else:
                keep = set(ids_to_keep)
                self.data = [d for d in self.data if d["id"] in keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """ This method runs the flow. It returns the input data of the flow with the demonstrations added to it.

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        :return: A new dictionary containing every key of the input data plus the "demonstrations" key
            (which overrides any "demonstrations" key already present in the input).
        :rtype: Dict[str, Any]
        """
        return {**input_data, "demonstrations": self._get_io_pairs(input_data=input_data)}
|