ChatWithDemonstrationsFlowModule / DemonstrationsAtomicFlow.py
nbaldwin's picture
chatwitDemV1
82b9374
raw
history blame
3.95 kB
import jinja2
from flows.base_flows import AtomicFlow
from flows.utils import logging
from flows.utils import general_helpers
from typing import Dict,Any,Optional,List
from flows.prompt_template import JinjaPrompt
from copy import deepcopy
import numpy as np
import os
import hydra
# Module-level logger obtained from the project's flows.utils.logging helper.
log = logging.get_logger(__name__)
class DemonstrationsAtomicFlow(AtomicFlow):
    """Atomic flow that adds rendered few-shot demonstrations to its input.

    Each demonstration is a dict ``{"idx", "query", "response"}``.
    - ``query`` is rendered from a datapoint's ``query_data`` with
      ``query_prompt_template``.
    - ``response`` is rendered from its ``response_data`` with
      ``response_prompt_template``.

    Datapoints come from the ``data`` constructor argument. When that is
    ``None``, they are read from ``<data_dir>/<demonstrations_id>.jsonl``
    (see ``_load_data``).
    """

    # Maximum number of demonstrations to emit; None means "use every datapoint".
    demonstrations_k: Optional[int] = None
    # Renders the query side of a demonstration (typically what the user asks).
    query_prompt_template: JinjaPrompt
    # Renders the response side of a demonstration (typically what the assistant answers).
    response_prompt_template: JinjaPrompt
    # Flow parameters. Keys read in this class: demonstrations_k, data_dir,
    # demonstrations_id, ids_to_keep.
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        """Store the templates/params and load demonstrations if none are given.

        Args:
            params: Mapping of flow parameters (``demonstrations_k`` is optional;
                ``data_dir``/``demonstrations_id`` are required when ``data`` is None).
            query_prompt_template: Template used to render each demonstration query.
            response_prompt_template: Template used to render each demonstration response.
            data: Pre-loaded list of datapoints; when None, ``_load_data`` is called.
            **kwargs: Forwarded to ``AtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)
        # Typically the query is what the user (human) asks the assistant (LLM).
        self.query_prompt_template = query_prompt_template
        # Typically the response is what the assistant (LLM) should answer the user (human).
        self.response_prompt_template = response_prompt_template
        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate the query and response prompt templates from ``config`` via Hydra.

        Returns:
            dict with keys ``query_prompt_template`` and ``response_prompt_template``.
        """
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config containing prompt templates, ``params`` and optional ``data``.

        ``config`` is deep-copied first, so the caller's object is not modified.
        A missing ``data`` key is treated as ``None``, which makes ``__init__``
        load demonstrations from disk.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update({"params": flow_config["params"]})
        # `data` is optional: __init__ falls back to _load_data when it is None.
        kwargs.update({"data": flow_config.get("data", None)})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    @staticmethod
    def _render(template, sample_data: Dict):
        """Format ``template`` with the entries of ``sample_data`` named by ``template.input_variables``.

        Raises:
            KeyError: if ``sample_data`` lacks one of the template's input variables.
        """
        return template.format(**{k: sample_data[k] for k in template.input_variables})

    def _get_query_message_content(self, sample_data: Dict):
        """Render the demonstration query from ``sample_data``."""
        return self._render(self.query_prompt_template, sample_data)

    def _get_response_message_content(self, sample_data: Dict):
        """Render the demonstration response from ``sample_data``."""
        return self._render(self.response_prompt_template, sample_data)

    def _get_io_pair(self, idx):
        """Return ``{"idx", "query", "response"}`` for datapoint ``self.data[idx]``.

        The datapoint must provide ``query_data`` and ``response_data`` mappings.
        """
        dp = self.data[idx]
        query = self._get_query_message_content(dp["query_data"])
        response = self._get_response_message_content(dp["response_data"])
        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """Return the first ``demonstrations_k`` demonstrations, or all when it is None.

        ``input_data`` is accepted for interface compatibility but is not used.
        """
        n_available = len(self.data)
        if self.demonstrations_k is None:
            k = n_available
        else:
            # Clamp so that asking for more demonstrations than exist does not raise IndexError.
            k = min(self.demonstrations_k, n_available)
        return [self._get_io_pair(idx) for idx in range(k)]

    def _load_data(self):
        """Load demonstrations from ``<data_dir>/<demonstrations_id>.jsonl`` into ``self.data``.

        If ``params["ids_to_keep"]`` is truthy, only datapoints whose ``id`` is listed
        are kept. It may be a comma-separated string or an iterable of ids. Ids are
        compared as strings, so ``"1, 2"`` matches ids stored as ints or as strings.
        """
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        # NOTE(review): read_jsonlines is assumed to return a list of dicts with an "id" key.
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        ids_to_keep = self.params.get("ids_to_keep", False)
        if ids_to_keep:
            if isinstance(ids_to_keep, str):
                # Strip whitespace so "a, b" yields "b" rather than " b".
                ids_to_keep = [token.strip() for token in ids_to_keep.split(",")]
            keep = {str(i) for i in ids_to_keep}
            self.data = [d for d in self.data if str(d["id"]) in keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``input_data`` with a ``"demonstrations"`` list added (or overwritten)."""
        return {**input_data, "demonstrations": self._get_io_pairs(input_data=input_data)}