|
import jinja2 |
|
from flows.base_flows import AtomicFlow |
|
from flows.utils import logging |
|
from flows.utils import general_helpers |
|
from typing import Dict,Any,Optional,List |
|
from flows.prompt_template import JinjaPrompt |
|
from copy import deepcopy |
|
import numpy as np |
|
import os |
|
import hydra |
|
log = logging.get_logger(__name__)


class DemonstrationsAtomicFlow(AtomicFlow):
    """Atomic flow that attaches rendered few-shot demonstrations to its input.

    Each demonstration datapoint is read as a mapping with a ``"query_data"``
    entry and a ``"response_data"`` entry. These are rendered with
    ``query_prompt_template`` and ``response_prompt_template`` respectively.
    Datapoints are either passed in directly (``data``) or read from
    ``<params["data_dir"]>/<params["demonstrations_id"]>.jsonl``.
    """

    # Maximum number of demonstrations emitted per run; ``None`` means "all loaded datapoints".
    demonstrations_k: Optional[int] = None

    # Template rendered from each datapoint's ``"query_data"``.
    query_prompt_template: JinjaPrompt

    # Template rendered from each datapoint's ``"response_data"``.
    response_prompt_template: JinjaPrompt

    # Flow parameters. Keys read in this class: ``demonstrations_k``, ``data_dir``,
    # ``demonstrations_id`` and ``ids_to_keep``.
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        """Create the flow.

        Args:
            params: Dict-like parameters (``.get`` and item access are used).
            query_prompt_template: Template used to render each datapoint's query.
            response_prompt_template: Template used to render each datapoint's response.
            data: Optional pre-loaded list of demonstration datapoints. When ``None``,
                the datapoints are loaded from disk via ``_load_data``.
            **kwargs: Forwarded unchanged to ``AtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)
        self.query_prompt_template = query_prompt_template
        self.response_prompt_template = response_prompt_template

        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate both prompt templates from their Hydra config entries.

        Returns:
            A dict with ``query_prompt_template`` and ``response_prompt_template``
            keys, suitable for passing as constructor kwargs.
        """
        kwargs = {}
        kwargs["query_prompt_template"] = \
            hydra.utils.instantiate(config['query_prompt_template'], _convert_="partial")
        kwargs["response_prompt_template"] = \
            hydra.utils.instantiate(config['response_prompt_template'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        The config must provide ``query_prompt_template``,
        ``response_prompt_template`` and ``params``. ``data`` is optional: when it
        is absent (or ``None``), the constructor loads the demonstrations from
        disk. A deep copy of the whole config is forwarded as ``flow_config``.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs["params"] = flow_config["params"]
        # Previously ``flow_config["data"]``, which raised KeyError for configs that
        # rely on file loading, even though ``__init__`` supports ``data=None``.
        kwargs["data"] = flow_config.get("data", None)

        return cls(**kwargs)

    @staticmethod
    def _render_template(template, sample_data: Dict):
        """Format ``template`` using only the variables it declares as inputs.

        Raises:
            KeyError: If ``sample_data`` lacks one of ``template.input_variables``.
        """
        return template.format(**{k: sample_data[k] for k in template.input_variables})

    def _get_query_message_content(self, sample_data: Dict):
        """Render the query side of a demonstration from ``sample_data``."""
        return self._render_template(self.query_prompt_template, sample_data)

    def _get_response_message_content(self, sample_data: Dict):
        """Render the response side of a demonstration from ``sample_data``."""
        return self._render_template(self.response_prompt_template, sample_data)

    def _get_io_pair(self, idx):
        """Return ``{"idx", "query", "response"}`` for the datapoint at position ``idx``."""
        dp = self.data[idx]
        query = self._get_query_message_content(dp["query_data"])
        response = self._get_response_message_content(dp["response_data"])
        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """Render the first ``demonstrations_k`` datapoints (or all of them if unset).

        ``input_data`` is not used by this implementation.
        """
        n_available = len(self.data)
        if self.demonstrations_k is None:
            k = n_available
        else:
            # Clamp so a k larger than the (possibly filtered) dataset does not
            # raise IndexError.
            k = min(self.demonstrations_k, n_available)
        return [self._get_io_pair(idx) for idx in range(k)]

    def _load_data(self):
        """Load demonstrations from ``<data_dir>/<demonstrations_id>.jsonl`` into ``self.data``.

        If ``params["ids_to_keep"]`` is truthy, only datapoints whose ``"id"`` is in
        it are kept. The value may be a comma-separated string or an iterable of ids.
        """
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        ids_to_keep = self.params.get("ids_to_keep", False)
        if ids_to_keep:
            if isinstance(ids_to_keep, str):
                # Tolerate "a, b,c" and trailing commas: strip whitespace, drop empties.
                ids_to_keep = {part.strip() for part in ids_to_keep.split(",") if part.strip()}
            else:
                ids_to_keep = set(ids_to_keep)

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``input_data`` with a ``"demonstrations"`` list added."""
        return {**input_data, "demonstrations": self._get_io_pairs(input_data=input_data)}
|
|