import os
from dotenv import load_dotenv
import pandas as pd
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.tools import PythonAstREPLTool
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from src.types import TableMapping
from src.prompt import (
    DATA_SCIENTIST_PROMPT_STR,
    SPEC_WRITER_PROMPT_STR,
    ENGINEER_PROMPT_STR,
)

load_dotenv()

if os.environ.get("DEBUG") == "true":
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
    os.environ["WANDB_PROJECT"] = "llm-data-mapper"


NUM_ROWS_TO_RETURN = 5
DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), "data")
SYNTHETIC_DATA_DIR_PATH = os.path.join(DATA_DIR_PATH, "synthetic")

# TODO: consider different models for different prompts; e.g. the natural-language prompt might do better with a higher temperature (see the sketch below).
BASE_MODEL = ChatOpenAI(
    model_name="gpt-4",
    temperature=0,
)
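# One possible split if the TODO above is pursued. Illustrative only: the
# variable names and temperatures are assumptions, not part of the current
# pipeline, which uses BASE_MODEL everywhere.
#
#   SPEC_WRITER_MODEL = ChatOpenAI(model_name="gpt-4", temperature=0.7)  # more creative prose spec
#   ENGINEER_MODEL = ChatOpenAI(model_name="gpt-4", temperature=0)       # deterministic code / parsing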


def _get_data_str_from_df_for_prompt(df, num_rows_to_return=NUM_ROWS_TO_RETURN):
    return f"<df>\n{df.head(num_rows_to_return).to_markdown()}\n</df>"


def get_table_mapping(source_df, template_df):
    """Use PydanticOutputParser to parse the output of the Data Scientist prompt into a TableMapping object."""
    table_mapping_parser = PydanticOutputParser(pydantic_object=TableMapping)
    analyst_prompt = ChatPromptTemplate.from_template(
        template=DATA_SCIENTIST_PROMPT_STR,
        partial_variables={
            "format_instructions": table_mapping_parser.get_format_instructions()
        },
    )
    mapping_chain = analyst_prompt | BASE_MODEL | table_mapping_parser
    table_mapping: TableMapping = mapping_chain.invoke(
        {
            "source_1_csv_str": _get_data_str_from_df_for_prompt(source_df),
            "target_csv_str": _get_data_str_from_df_for_prompt(template_df),
        }
    )
    return pd.DataFrame(table_mapping.dict()["table_mappings"])


def _sanitize_python_output(text: str) -> str:
    """Strip the markdown code fence the engineer prompt wraps around its Python output."""
    if "```python" not in text:
        # Model replied with bare code and no fence; pass it through unchanged.
        return text
    _, after = text.split("```python", 1)
    return after.split("```")[0]
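# e.g. a raw reply of "Here is the code:\n```python\ndf = source_df.copy()\n```"
# is reduced to "\ndf = source_df.copy()\n" before it reaches the REPL tool.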


def generate_mapping_code(table_mapping_df) -> str:
    """Chain two prompts together to generate python code from a table mapping: 1. technical spec writer, 2. python engineer"""
    writer_prompt = ChatPromptTemplate.from_template(SPEC_WRITER_PROMPT_STR)
    engineer_prompt = ChatPromptTemplate.from_template(ENGINEER_PROMPT_STR)

    writer_chain = writer_prompt | BASE_MODEL | StrOutputParser()
    engineer_chain = (
        {"spec_str": writer_chain}
        | engineer_prompt
        | BASE_MODEL
        | StrOutputParser()
        | _sanitize_python_output
    )
    return engineer_chain.invoke({"table_mapping": str(table_mapping_df.to_dict())})


def process_csv_text(value):
    """Process a CSV file into a dataframe, either from a string path or a file."""
    if isinstance(value, str):
        df = pd.read_csv(value)
    else:
        df = pd.read_csv(value.name)
    return df


def transform_source(source_df, code_text: str):
    """Use PythonAstREPLTool to transform a source dataframe using python code."""
    return PythonAstREPLTool(locals={"source_df": source_df}).run(code_text)
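

# Illustrative end-to-end run of the pipeline above. The CSV filenames are
# hypothetical placeholders; point them at real files under DATA_DIR_PATH.
if __name__ == "__main__":
    source_df = process_csv_text(os.path.join(DATA_DIR_PATH, "source.csv"))
    template_df = process_csv_text(os.path.join(DATA_DIR_PATH, "template.csv"))

    # 1. Ask the Data Scientist prompt for a column-level mapping.
    table_mapping_df = get_table_mapping(source_df, template_df)
    print(table_mapping_df)

    # 2. Turn the mapping into a natural-language spec, then into runnable Python.
    mapping_code = generate_mapping_code(table_mapping_df)

    # 3. Execute the generated code against the source dataframe via the REPL tool.
    print(transform_source(source_df, mapping_code))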