|
import pandas as pd |
|
from pathlib import Path |
|
import requests |
|
import streamlit as st |
|
from langchain.agents import create_pandas_dataframe_agent |
|
from langchain.llms import OpenAI |
|
from contextlib import contextmanager, redirect_stdout |
|
from io import StringIO |
|
from time import sleep |
|
|
|
# Local cache of the dataset; downloaded once on first run.
data_path = Path('umpire-full-text.csv')

if not data_path.exists():
    # A timeout keeps a stalled connection from hanging app start-up forever.
    r = requests.get(
        'https://upenn.box.com/shared/static/dyxc1heqrfrgp22ntwpet3f57sir34c3.csv',
        timeout=60,
    )
    # Fail loudly on HTTP errors instead of caching an error page as the CSV:
    # once the file exists it is never re-downloaded, so a bad body would
    # otherwise poison every later run.
    r.raise_for_status()
    data_path.write_bytes(r.content)
|
|
|
@contextmanager
def st_capture(output_func):
    """Redirect stdout into a buffer and forward the buffer on every write.

    While the context is active, ``sys.stdout`` is replaced by an in-memory
    ``StringIO``. After each ``write`` call, ``output_func`` is invoked with
    the *entire* text accumulated so far (not just the newly written piece).

    Args:
        output_func: Callable taking one string argument.

    Yields:
        None. The buffer is closed and stdout restored when the context exits.
    """
    buffer = StringIO()
    with buffer:
        with redirect_stdout(buffer):
            raw_write = buffer.write

            def forwarding_write(text):
                # Perform the real write first so getvalue() includes `text`.
                written = raw_write(text)
                output_func(buffer.getvalue())
                return written

            # Shadow the bound method on this instance only.
            buffer.write = forwarding_write
            yield
|
|
|
|
|
# Load the dataset and build an LLM agent that answers questions about it.
df = pd.read_csv('umpire-full-text.csv')
# verbose=True: the agent's intermediate output is expected on stdout, which
# is what st_capture intercepts below.
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df, verbose=True)

st.dataframe(df)

query = st.text_input('Enter query here:', '')
# Placeholder reserved *above* the log panel; filled once the agent finishes.
answer = st.empty()

if query:
    # Second placeholder that shows the captured stdout while the agent runs.
    output = st.empty()
    with st_capture(output.info):
        response = agent.run(query)
    # Write into the reserved placeholder. The previous code called
    # st.write(response), which appended a new element below the log and
    # rebound `answer` to st.write's return value, leaving the placeholder
    # permanently empty.
    answer.write(response)
|
|
|
|