import os
from dotenv import load_dotenv
import pandas as pd
import io
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.tools import PythonAstREPLTool
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from src.types import TableMapping
from src.prompt import (
    DATA_SCIENTIST_PROMPT_STR,
    SPEC_WRITER_PROMPT_STR,
    ENGINEER_PROMPT_STR,
)
load_dotenv()

# When DEBUG is enabled, trace chain runs to Weights & Biases via LangChain's wandb integration.
if os.environ.get("DEBUG") == "true":
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
    os.environ["WANDB_PROJECT"] = "llm-data-mapper"
NUM_ROWS_TO_RETURN = 5
DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), "data")
SYNTHETIC_DATA_DIR_PATH = os.path.join(DATA_DIR_PATH, "synthetic")
# TODO: consider different models for different prompts, e.g. natural language prompt might be better with higher temperature
BASE_MODEL = ChatOpenAI(
    model_name="gpt-4",
    temperature=0,
)


def _get_data_str_from_df_for_prompt(df, num_rows_to_return=NUM_ROWS_TO_RETURN):
    """Render the first rows of a dataframe as a markdown table wrapped in <df> tags for prompts."""
    return f"<df>\n{df.head(num_rows_to_return).to_markdown()}\n</df>"


def get_table_mapping(source_df, template_df):
    """Use PydanticOutputParser to parse the output of the Data Scientist prompt into a TableMapping object."""
    table_mapping_parser = PydanticOutputParser(pydantic_object=TableMapping)
    analyst_prompt = ChatPromptTemplate.from_template(
        template=DATA_SCIENTIST_PROMPT_STR,
        partial_variables={
            "format_instructions": table_mapping_parser.get_format_instructions()
        },
    )
    mapping_chain = analyst_prompt | BASE_MODEL | table_mapping_parser
    table_mapping: TableMapping = mapping_chain.invoke(
        {
            "source_1_csv_str": _get_data_str_from_df_for_prompt(source_df),
            "target_csv_str": _get_data_str_from_df_for_prompt(template_df),
        }
    )
    return pd.DataFrame(table_mapping.dict()["table_mappings"])


def _sanitize_python_output(text: str):
    """Strip the markdown code fence that the prompt wraps around the generated python code."""
    _, after = text.split("```python")
    return after.split("```")[0]


def generate_mapping_code(table_mapping_df) -> str:
    """Chain two prompts together to generate python code from a table mapping: 1. technical spec writer, 2. python engineer."""
    writer_prompt = ChatPromptTemplate.from_template(SPEC_WRITER_PROMPT_STR)
    engineer_prompt = ChatPromptTemplate.from_template(ENGINEER_PROMPT_STR)
    writer_chain = writer_prompt | BASE_MODEL | StrOutputParser()
    engineer_chain = (
        {"spec_str": writer_chain}
        | engineer_prompt
        | BASE_MODEL
        | StrOutputParser()
        | _sanitize_python_output
    )
    return engineer_chain.invoke({"table_mapping": str(table_mapping_df.to_dict())})


def process_csv_text(value):
    """Process a CSV file into a dataframe, either from a string path or a file."""
    if isinstance(value, str):
        df = pd.read_csv(value)
    else:
        df = pd.read_csv(value.name)
    return df


def transform_source(source_df, code_text: str):
    """Use PythonAstREPLTool to transform a source dataframe using python code."""
    return PythonAstREPLTool(locals={"source_df": source_df}).run(code_text)
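

# Minimal end-to-end sketch (illustrative addition, not part of the original pipeline):
# it assumes OPENAI_API_KEY is provided via .env and that the two CSV paths below
# exist; the file names "source.csv" and "template.csv" are hypothetical.
if __name__ == "__main__":
    source_df = process_csv_text(os.path.join(DATA_DIR_PATH, "source.csv"))
    template_df = process_csv_text(os.path.join(DATA_DIR_PATH, "template.csv"))

    # 1. Ask the model for a column-level mapping between source and template.
    mapping_df = get_table_mapping(source_df, template_df)

    # 2. Turn that mapping into a technical spec, then into python transformation code.
    code_text = generate_mapping_code(mapping_df)

    # 3. Run the generated code against the source dataframe.
    print(transform_source(source_df, code_text))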