# Streamlit app: upload a CSV/XLSX table, ask questions about it with the
# TAPAS table-question-answering pipeline, and draw simple Plotly graphs
# ("graph between A and B", "graph of column A").
#
# NOTE: this header is a comment rather than a module docstring so that it can
# never be rendered into the page as a bare top-level expression.
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st
from st_aggrid import AgGrid
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer

# Set the page layout for Streamlit
st.set_page_config(layout="wide")

# Characters trimmed from both ends of a column name typed inside a question
# (whitespace, quotes and sentence punctuation such as a trailing "?").
_EDGE_CHARS = " \t\r\n?.!,;:'\"`"


@st.cache_resource
def _load_tapas_pipeline():
    """Build the TAPAS pipeline once instead of on every script rerun."""
    return pipeline(
        task="table-question-answering",
        model="google/tapas-large-finetuned-wtq",
        device="cpu",
    )


@st.cache_resource
def _load_t5():
    """Load the T5 tokenizer and model once; returns ``(tokenizer, model)``."""
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return tokenizer, model


# Initialize TAPAS pipeline
tqa = _load_tapas_pipeline()

# Initialize T5 tokenizer and model for text generation.
# NOTE(review): neither T5 object is used anywhere in this script; the names
# are kept in case other code relies on them.
t5_tokenizer, t5_model = _load_t5()


def _resolve_column(token, columns):
    """Return the label in *columns* that *token* refers to, or ``None``.

    Tries, in order: the token as typed (whitespace-trimmed), the token with
    surrounding quotes/punctuation removed, and finally a case-insensitive
    match. Labels are compared by their ``str()`` form so that non-string
    column labels can still be addressed from a typed question.
    """
    by_name = {str(label): label for label in columns}
    cleaned = token.strip(_EDGE_CHARS)
    for candidate in (token.strip(), cleaned):
        if candidate in by_name:
            return by_name[candidate]
    folded = {name.casefold(): label for name, label in by_name.items()}
    return folded.get(cleaned.casefold())


def _between_tokens(question):
    """Return the two raw column tokens of a "... between A and B" question.

    Returns an empty tuple when the question does not have that shape. Whole
    words are matched case-insensitively, so "Between" works and a column
    such as "brand" is not split on its embedded "and".
    """
    match = re.search(
        r"\bbetween\b(.+?)\band\b(.+)",
        question,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if match is None:
        return ()
    return match.group(1), match.group(2)


def _single_column_tokens(question):
    """Return candidate column tokens for a "graph of column X" question.

    The first candidate is the original heuristic (text after the last
    literal "of"); the others cover the documented example phrasing
    "make a graph of column sales?", which the original heuristic missed.
    """
    tokens = [question.split("of")[-1]]
    after_column = re.search(
        r"\bcolumn\b(.+)", question, flags=re.IGNORECASE | re.DOTALL
    )
    if after_column is not None:
        tokens.append(after_column.group(1))
    after_of = re.split(r"\bof\b", question, flags=re.IGNORECASE)[-1]
    tokens.append(re.sub(r"\bcolumn\b", "", after_of, flags=re.IGNORECASE))
    return tokens


def _numeric_cells(cells):
    """Return the entries of *cells* that parse as floats, in order."""
    values = []
    for cell in cells:
        try:
            values.append(float(cell))
        except (TypeError, ValueError):
            # Non-numeric answer cells are skipped instead of aborting the plot.
            continue
    return values


def _answer_table_question(question, table):
    """Run TAPAS on the string-typed *table* and display the answer.

    For questions containing "average", the numeric answer cells are also
    drawn as a line chart.
    """
    raw_answer = tqa(table=table, query=question, truncation=True)

    # Display raw answer from TAPAS on the screen
    st.markdown("Raw TAPAS Answer:", unsafe_allow_html=True)
    st.write(raw_answer)

    # Extract relevant values for Plotly
    answer = raw_answer.get('answer', '')
    coordinates = raw_answer.get('coordinates', [])
    cells = raw_answer.get('cells', [])

    st.markdown("Relevant Data for Plotly:", unsafe_allow_html=True)
    st.write(f"Answer: {answer}")
    st.write(f"Coordinates: {coordinates}")
    st.write(f"Cells: {cells}")

    plot_data = _numeric_cells(cells)
    if "average" in question.lower() and plot_data:
        plot_df = pd.DataFrame({
            'Index': list(range(1, len(plot_data) + 1)),
            'Value': plot_data,
        })
        fig = px.line(plot_df, x='Index', y='Value', title=f"Graph for '{question}'")
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.write(f"No data to plot for the question: '{question}'")


def _draw_graph(question, table):
    """Draw the graph requested by *question* from the numeric *table*.

    Supports "... between A and B" (scatter plot) and "... column A" (line
    plot against the index). Unknown columns and unrecognised requests
    produce a warning instead of silently doing nothing.
    """
    lowered = question.lower()
    if 'between' in lowered and 'and' in lowered:
        resolved = [
            _resolve_column(token, table.columns)
            for token in _between_tokens(question)
        ]
        if len(resolved) == 2 and all(label is not None for label in resolved):
            x_col, y_col = resolved
            fig = px.scatter(
                table, x=x_col, y=y_col,
                title=f"Graph between {x_col} and {y_col}",
            )
            st.plotly_chart(fig, use_container_width=True)
            st.success(f"Here is the graph between '{x_col}' and '{y_col}'.")
        else:
            st.warning("Columns not found in the dataset.")
    elif 'column' in lowered:
        column = None
        for token in _single_column_tokens(question):
            column = _resolve_column(token, table.columns)
            if column is not None:
                break
        if column is not None:
            fig = px.line(
                table, x=table.index, y=column,
                title=f"Graph of column '{column}'",
            )
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("Column not found in the dataset.")
    else:
        st.warning(
            "Could not understand the graph request. Try "
            "\"make a graph of column sales\" or "
            "\"make a graph between sales and expenses\"."
        )


# Title and Introduction
st.title("Table Question Answering and Data Analysis App")
st.markdown("""
This app allows you to upload a table (CSV or Excel) and ask questions about the data.
Based on your question, it will provide the corresponding answer using the **TAPAS** model and additional data processing.

### Available Features:
- **mean()**: For "average", it computes the mean of the entire numeric DataFrame.
- **sum()**: For "sum", it calculates the sum of all numeric values in the DataFrame.
- **max()**: For "max", it computes the maximum value in the DataFrame.
- **min()**: For "min", it computes the minimum value in the DataFrame.
- **count()**: For "count", it counts the non-null values in the entire DataFrame.
- **Graph Generation**: You can ask questions like "make a graph of column sales?" or "make a graph between sales and expenses?". The app will generate interactive graphs for you.

Upload your data and ask questions to get both answers and visualizations.
""")

# File uploader in the sidebar
file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])

# Both tables stay None until a file has been read successfully, so the
# question section below can tell "no data yet" apart from a real error.
df = None          # string-typed table handed to TAPAS
df_numeric = None  # copy with numeric dtypes preserved, used for plotting

# File processing
if file_name is None:
    st.markdown("Please upload an excel or csv file", unsafe_allow_html=True)
else:
    try:
        # Check file type (case-insensitively) and read accordingly
        lowered_name = file_name.name.lower()
        if lowered_name.endswith('.csv'):
            # NOTE: the separator and encoding are fixed; a comma-separated
            # file will load as a single column.
            loaded = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1')
        elif lowered_name.endswith('.xlsx'):
            loaded = pd.read_excel(file_name, engine='openpyxl')
        else:
            st.error("Unsupported file type")
            loaded = None

        if loaded is not None:
            # Convert text columns that are entirely numeric; leave the rest
            # untouched (same effect as the deprecated errors='ignore').
            for col in loaded.select_dtypes(include=['object']).columns:
                try:
                    loaded[col] = pd.to_numeric(loaded[col])
                except (ValueError, TypeError):
                    continue

            st.write("Original Data:")
            st.write(loaded)

            df_numeric = loaded.copy()
            df = loaded.astype(str)  # TAPAS is given the table as text

            # Display the first 5 rows of the dataframe in an editable grid
            grid_response = AgGrid(
                df.head(5),
                fit_columns_on_grid_load=True,
                editable=True,
                height=300,
                width='100%',
            )
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")

# User input for the question
question = st.text_input('Type your question')

# A question mentioning "graph" is routed to Plotly instead of TAPAS
is_graph_query = 'graph' in question.lower()

# Process the answer using TAPAS, or draw the requested graph
with st.spinner():
    if st.button('Answer'):
        if df is None:
            st.warning("Please upload an excel or csv file before asking a question.")
        elif not question.strip():
            st.warning("Please type a question first.")
        else:
            try:
                if not is_graph_query:
                    _answer_table_question(question, df)
                else:
                    _draw_graph(question, df_numeric)
                    # Kept from the original script: halt the run after a
                    # graph request has been handled.
                    st.stop()
            except Exception as e:
                st.warning(f"Error processing question or generating answer: {str(e)}")
                st.warning("Please retype your question and make sure to use the column name and cell value correctly.")