import os
from copy import deepcopy
from typing import Any, Dict, List, Optional

import hydra

from flows.base_flows import AtomicFlow
from flows.prompt_template import JinjaPrompt
from flows.utils import general_helpers, logging

log = logging.get_logger(__name__)


class DemonstrationsAtomicFlow(AtomicFlow):
    """ This class implements a Demonstrations Atomic Flow. It is typically used to pass demonstrations (of user-assistant interactions)
    to the ChatAtomicFlow.
    
    *Configuration Parameters*:
    
    - `name` (str): The name of the flow. Default: "DemonstrationsAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
    Default: "A flow that passes demonstrations to the ChatFlow"
    - `data` (List[Dict[str, Any]]): The data of the demonstrations.
    If `data` is None, the data is loaded from the file specified by `params["data_dir"]` and `params["demonstrations_id"]`.
    Default: no default value; this field must be set (it may be None).
    - `params` (Dict[str, Any]): The parameters specific to the dataset of the demonstrations. Its keys are:
        - `data_dir` (str): The directory where the demonstrations are stored. If the data is not passed directly to the flow through `data`,
        it is loaded from this directory. Default: no default value; this field must be set.
        - `demonstrations_id` (str): The id of the demonstrations (the name of the data file, without the `.jsonl` extension). If the data is not passed
        directly to the flow through `data`, it is loaded from this file. Default: no default value; this field must be set.
        - `demonstrations_k` (int): The number of demonstrations to pass to the ChatFlow.
        If None, all the demonstrations are passed to the ChatFlow. Default: None
    - `query_prompt_template` (Dict[str, Any]): The prompt template used to generate the query of the demonstrations.
    By default it is of type flows.prompt_template.JinjaPrompt. None of the prompt's parameters are defined by default, so they must be set in order
    to use the query_prompt_template. Default parameters are defined in flows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `response_prompt_template` (Dict[str, Any]): The prompt template used to generate the response of the demonstrations.
    By default it is of type flows.prompt_template.JinjaPrompt. None of the prompt's parameters are defined by default, so they must be set in order
    to use the response_prompt_template. Default parameters are defined in flows.prompt_template.jinja2_prompts.JinjaPrompt.

    *Input Interface*:
    
    - The input interface expected by its successor flow (e.g. typically ChatAtomicFlow so the input interface is the one expected by ChatAtomicFlow)
        
    *Output Interface*:

    - The input data, passed through unchanged (typically the input interface expected by the successor flow, e.g. ChatAtomicFlow)
    - `demonstrations` (List[Dict[str, Any]]): A list of demonstrations. Each demonstration is a dictionary with the following keys:
        - `idx` (int): The index of the demonstration
        - `query` (str): The query of the demonstration
        - `response` (str): The response of the demonstration
        
    :param params: The parameters specific to the dataset of the demonstrations. It should contain the following keys:
                        - 'data_dir' (str): The directory where the demonstrations are stored. This field is used if the data is not directly passed to the flow through the 'data' field.
                        - 'demonstrations_id' (str): The id of the demonstrations (name of the data file). This field is used if the data is not directly passed to the flow through the 'data' field.
                        - 'demonstrations_k' (int): The number of demonstrations to pass to the ChatFlow. If None, all the demonstrations are passed to the ChatFlow.
                        - 'ids_to_keep' (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep. If None, all the demonstrations are kept.
    :type params: Dict[str, Any]
    :param query_prompt_template: The prompt template used to generate the query of the demonstrations.
    :type query_prompt_template: JinjaPrompt
    :param response_prompt_template: The prompt template used to generate the response of the demonstrations.
    :type response_prompt_template: JinjaPrompt
    :param data: The data of the demonstrations. If None, the data is loaded from the file specified in the params.
    :type data: Optional[List[Dict[str, Any]]]
    """
    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict
    
    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)

        # Typically, the query is what the user (human) asks the assistant (LLM).
        self.query_prompt_template = query_prompt_template
        # Typically, the response is what the assistant (LLM) should answer to the user (human).
        self.response_prompt_template = response_prompt_template
        if self.data is None:
            self._load_data()
            
    @classmethod
    def _set_up_prompts(cls, config):
        """ This method instantiates the prompt templates of the flow (used when instantiating the flow from a config file)
        
            :param config: The configuration of the flow.
            :type config: Dict[str, Any]
            :return: A dictionary of keyword arguments to pass to the constructor of the flow.
            :rtype: Dict[str, Any]
        """
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs
        
    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a config file. 
        
        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: Flow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update({"params": flow_config["params"]})
        kwargs.update({"data": flow_config["data"]})
        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)
            
    def _get_query_message_content(self, sample_data: Dict):
        """ This method returns the query message content of a demonstration given the sample data (by rendering the query prompt template).
        
        :param sample_data: The sample data of the demonstration.
        :type sample_data: Dict[str, Any]
        :return: The query message content of the demonstration.
        :rtype: str
        """
        input_variables = self.query_prompt_template.input_variables
        return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_response_message_content(self, sample_data: Dict):
        """ This method returns the response message content of a demonstration given the sample data (by rendering the response prompt template).
        
        :param sample_data: The sample data of the demonstration.
        :type sample_data: Dict[str, Any]
        :return: The response message content of the demonstration.
        :rtype: str
        """
        input_variables = self.response_prompt_template.input_variables
        return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})
        
    def _get_io_pair(self, idx):
        """ This method, given the index of a demonstration, returns a query-response pair from the demonstrations data.
        
        :param idx: The index of the demonstration.
        :type idx: int
        :return: The query-response pair at idx from the demonstrations data.
        :rtype: Dict[str, Any]
        """
        dp = self.data[idx]
        
        query_data = dp["query_data"]
        response_data = dp["response_data"]
        
        query = self._get_query_message_content(query_data)
        response = self._get_response_message_content(response_data)
        
        return {"idx": idx, "query": query, "response": response}
            
    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """ This method returns the demonstrations that are passed to the destination flow (typically ChatAtomicFlow).
        
        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        :return: The demonstrations that are passed to the destination flow.
        :rtype: List[Any]
        """
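        # Clamp to the available data so that a demonstrations_k larger than the dataset does not raise an IndexError.
        demonstrations_k = len(self.data) if self.demonstrations_k is None else min(self.demonstrations_k, len(self.data))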
        io_pairs = [self._get_io_pair(idx) for idx in range(demonstrations_k)]
        return io_pairs
        
    def _load_data(self):
        """ This method loads the demonstrations from the file specified in the params. It also filters the demonstrations if the ids_to_keep parameter is specified."""
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)
        
        if self.params.get("ids_to_keep", False):
            if isinstance(self.params["ids_to_keep"], str):
                ids_to_keep = set(self.params["ids_to_keep"].split(","))
            else:
                ids_to_keep = set(self.params["ids_to_keep"])

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """ This method runs the flow. It returns the input data of the flow with the demonstrations added to it.
        
        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        :return: The input data of the flow with the demonstrations added to it.
        :rtype: Dict[str, Any]
        """
        return {**input_data, "demonstrations": self._get_io_pairs(input_data=input_data)}
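

"""
Usage sketch (illustrative only, not part of the flow). The snippet below mimics,
with the standard library alone, how a record holding `query_data` and
`response_data` ends up as an `{"idx", "query", "response"}` entry of the
`demonstrations` output, including the `ids_to_keep` filter and the
`demonstrations_k` cut-off. Everything in it is an assumption made for
illustration: `build_demonstrations`, `QUERY_TEMPLATE`, `RESPONSE_TEMPLATE` and
the sample records are hypothetical, and `str.format` merely stands in for the
JinjaPrompt templates that the flow actually renders.

```python
from typing import Any, Dict, List, Optional, Union

# Hypothetical stand-ins for the query and response prompt templates.
QUERY_TEMPLATE = "Q: {question}"
RESPONSE_TEMPLATE = "A: {answer}"

# Illustrative records in the shape the flow reads: an "id" used for filtering,
# plus "query_data" and "response_data" dictionaries used for rendering.
records: List[Dict[str, Any]] = [
    {"id": "a", "query_data": {"question": "2 + 2?"}, "response_data": {"answer": "4"}},
    {"id": "b", "query_data": {"question": "3 * 3?"}, "response_data": {"answer": "9"}},
    {"id": "c", "query_data": {"question": "10 / 2?"}, "response_data": {"answer": "5"}},
]


def build_demonstrations(
    data: List[Dict[str, Any]],
    k: Optional[int] = None,
    ids_to_keep: Optional[Union[str, List[str]]] = None,
) -> List[Dict[str, Any]]:
    # Filter by id first; a comma-separated string is accepted as well as a list.
    if ids_to_keep:
        keep = set(ids_to_keep.split(",")) if isinstance(ids_to_keep, str) else set(ids_to_keep)
        data = [d for d in data if d["id"] in keep]
    # Keep the first k records; k is clamped to the data size in this sketch.
    k = len(data) if k is None else min(k, len(data))
    return [
        {
            "idx": idx,
            "query": QUERY_TEMPLATE.format(**data[idx]["query_data"]),
            "response": RESPONSE_TEMPLATE.format(**data[idx]["response_data"]),
        }
        for idx in range(k)
    ]


demos = build_demonstrations(records, k=2, ids_to_keep="a,c")
for demo in demos:
    print(demo)
```

Note that `idx` is the position in the filtered list, not the record's `id`, so
filtering changes which record a given `idx` refers to.
"""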