nbaldwin commited on
Commit
82b9374
·
1 Parent(s): dfac45b

chatwitDemV1

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
ChatWithDemonstrationsFlow.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from flows.base_flows import SequentialFlow
4
+ from flows.utils import logging
5
+
6
+ logging.set_verbosity_debug()
7
+
8
+ log = logging.get_logger(__name__)
9
+
10
+
11
+ class ChatWithDemonstrationsFlow(SequentialFlow):
12
+ def __init__(self,**kwargs):
13
+ super().__init__(**kwargs)
ChatWithDemonstrationsFlow.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: "ChatAtomic_Flow_with_Demonstrations"
2
+ description: "A sequential flow that answers questions with demonstrations"
3
+
4
+ subflows_config:
5
+ demonstration_flow:
6
+ _target_: aiflows.ChatWithDemonstrationsFlowModule.DemonstrationsAtomicFlow.instantiate_from_default_config
7
+
8
+ chat_flow:
9
+ _target_: aiflows.OpenAIChatFlowModule.OpenAIChatAtomicFlow.instantiate_from_default_config
10
+
11
+ topology:
12
+ - goal: Get Demonstrations
13
+ input_interface:
14
+ _target_: flows.interfaces.KeyInterface
15
+ # circular flow as the orchestrator, prepare the correct input for the agent
16
+ flow: demonstration_flow
17
+ output_interface:
18
+ _target_: flows.interfaces.KeyInterface
19
+
20
+ - goal: Answer the question
21
+ input_interface:
22
+ _target_: flows.interfaces.KeyInterface
23
+ # circular flow as the orchestrator, prepare the correct input for the agent
24
+ flow: chat_flow
25
+ output_interface:
26
+ _target_: flows.interfaces.KeyInterface
27
+ keys_to_rename:
28
+ api_output: answer # Rename the api_output to answer
DemonstrationsAtomicFlow.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jinja2
2
+ from flows.base_flows import AtomicFlow
3
+ from flows.utils import logging
4
+ from flows.utils import general_helpers
5
+ from typing import Dict,Any,Optional,List
6
+ from flows.prompt_template import JinjaPrompt
7
+ from copy import deepcopy
8
+ import numpy as np
9
+ import os
10
+ import hydra
11
+ log = logging.get_logger(__name__)
12
+
13
+ class DemonstrationsAtomicFlow(AtomicFlow):
14
+ demonstrations_k: Optional[int] = None
15
+ query_prompt_template: JinjaPrompt
16
+ response_prompt_template: JinjaPrompt
17
+ params: Dict
18
+
19
+ def __init__(self,params,query_prompt_template,response_prompt_template, data=None,**kwargs):
20
+ super().__init__(**kwargs)
21
+ self.params = params
22
+ self.data = data
23
+ self.demonstrations_k = self.params.get("demonstrations_k", None)
24
+
25
+ #typically the query would be what the user (human) asks the assistant (LLM)
26
+ self.query_prompt_template = query_prompt_template
27
+ #typically the response would be what the assistant (LLM) should answer to the user (human)
28
+ self.response_prompt_template = response_prompt_template
29
+ if self.data is None:
30
+ self._load_data()
31
+
32
+ @classmethod
33
+ def _set_up_prompts(cls, config):
34
+ kwargs = {}
35
+ kwargs["query_prompt_template"] = \
36
+ hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
37
+ kwargs["response_prompt_template"] = \
38
+ hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
39
+ return kwargs
40
+
41
+ @classmethod
42
+ def instantiate_from_config(cls, config):
43
+ flow_config = deepcopy(config)
44
+
45
+ kwargs = {"flow_config": flow_config}
46
+
47
+ # ~~~ Set up prompts ~~~
48
+ kwargs.update(cls._set_up_prompts(flow_config))
49
+ kwargs.update({"params": flow_config["params"]})
50
+ kwargs.update({"data": flow_config["data"]})
51
+ # ~~~ Instantiate flow ~~~
52
+ return cls(**kwargs)
53
+
54
+ def _get_query_message_content(self, sample_data: Dict):
55
+ input_variables = self.query_prompt_template.input_variables
56
+ return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})
57
+
58
+ def _get_response_message_content(self, sample_data: Dict):
59
+ input_variables = self.response_prompt_template.input_variables
60
+ return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})
61
+
62
+ def _get_io_pair(self, idx):
63
+ dp = self.data[idx]
64
+
65
+ query_data = dp["query_data"]
66
+ response_data = dp["response_data"]
67
+
68
+ query = self._get_query_message_content(query_data)
69
+ response = self._get_response_message_content(response_data)
70
+
71
+ return {"idx": idx, "query": query,"response": response}
72
+
73
+ def _get_io_pairs(self,input_data: Dict[str, Any]) -> List[Any]:
74
+ demonstrations_k = self.demonstrations_k if self.demonstrations_k is not None else len(self.data)
75
+ io_pairs = [self._get_io_pair(idx) for idx in range(demonstrations_k)]
76
+ return io_pairs
77
+
78
+ def _load_data(self):
79
+ demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
80
+ self.data = general_helpers.read_jsonlines(demonstrations_file)
81
+
82
+ if self.params.get("ids_to_keep", False):
83
+ if isinstance(self.params["ids_to_keep"], str):
84
+ ids_to_keep = set(self.params["ids_to_keep"].split(","))
85
+ else:
86
+ ids_to_keep = set(self.params["ids_to_keep"])
87
+
88
+ self.data = [d for d in self.data if d["id"] in ids_to_keep]
89
+
90
+ log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])
91
+
92
+ def run(self,
93
+ input_data: Dict[str, Any]) -> Dict[str, Any]:
94
+ return {**input_data,**{"demonstrations": self._get_io_pairs(input_data=input_data)}}
DemonstrationsAtomicFlow.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: "DemonstrationAtomicFlow"
2
+ description: "A flow which returns Demonstrations"
3
+
4
+
5
+
6
+
7
+ data: ??? #e.g. [{"query_data": {"query": "What is the capital of France?"}, "response_data": {"response": "Paris, my sir."}}]
8
+ params:
9
+ data_dir: ???
10
+ demonstrations_id: ???
11
+ ids_to_keep: null
12
+ demonstrations_k: null
13
+
14
+ query_prompt_template:
15
+ _target_: flows.prompt_template.JinjaPrompt
16
+
17
+ response_prompt_template:
18
+ _target_: flows.prompt_template.JinjaPrompt
__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # ~~~ Specify the dependencies ~~
2
+ dependencies = [
3
+ {"url": "aiflows/OpenAIChatFlowModule", "revision": "d69ba2125de99d2edb631dd51d22225ed9e3446c"},
4
+ ]
5
+ from flows import flow_verse
6
+ flow_verse.sync_dependencies(dependencies)
7
+ from .ChatWithDemonstrationsFlow import ChatWithDemonstrationsFlow
8
+ from .DemonstrationsAtomicFlow import DemonstrationsAtomicFlow