|
import hashlib
import os
import tempfile

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
|
|
|
|
|
def add_vertical_space(spaces=1):
    """Separate sidebar sections by drawing markdown rules.

    Args:
        spaces: How many ``---`` horizontal rules to emit in the sidebar.
            Zero or negative values emit nothing.
    """
    sidebar = st.sidebar
    for _rule_number in range(spaces):
        sidebar.markdown("---")
|
|
|
def plot_histogram(df):
    """Render a histogram of a user-selected numeric column.

    Args:
        df: DataFrame whose numeric columns are offered in a selectbox.

    Writes a notice instead of a chart when ``df`` has no numeric columns.
    """
    # 'number' covers every int/float width (int32, float32, nullable Int64,
    # ...) instead of only int64/float64. timedelta64 is a numpy integer
    # subtype, so it would otherwise be selected; it is excluded explicitly.
    numeric_columns = df.select_dtypes(include='number', exclude='timedelta').columns.tolist()

    if not numeric_columns:
        st.write("No numeric columns available for plotting.")
        return

    selected_column = st.selectbox('Select column for histogram', numeric_columns, key='hist_col')
    # Drop missing values so they do not interfere with bin-range detection;
    # float64 also normalises nullable integer dtypes for matplotlib.
    values = df[selected_column].dropna().astype('float64')

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(values, bins=20, alpha=0.75)
    ax.set_title(f'Distribution of {selected_column}')
    st.pyplot(fig)
    # Passing an explicit figure means Streamlit does not clear it, so close
    # it here; otherwise every rerun leaves another open matplotlib figure.
    plt.close(fig)
|
|
|
def plot_scatter(df):
    """Render a scatter plot of two user-selected numeric columns.

    Args:
        df: DataFrame whose numeric columns are offered in two selectboxes.

    Writes a notice instead of a chart when ``df`` has fewer than two
    numeric columns.
    """
    # 'number' covers every int/float width; timedelta64 (a numpy integer
    # subtype) is excluded because it cannot be plotted as a plain number.
    numeric_columns = df.select_dtypes(include='number', exclude='timedelta').columns.tolist()

    if len(numeric_columns) < 2:
        st.write("Need at least two numeric columns to create a scatter plot.")
        return

    col1 = st.selectbox('Select the first variable', numeric_columns, key='first_col')
    # Default to a different column than the first box so the initial chart
    # is not a column plotted against itself.
    col2 = st.selectbox('Select the second variable', numeric_columns, index=1, key='second_col')

    # Keep only rows where both coordinates are present. A boolean mask is
    # used instead of df[[col1, col2]] because that would return a two-column
    # frame for df[col1] when the same column is chosen twice.
    mask = df[col1].notna() & df[col2].notna()
    x_values = df.loc[mask, col1].astype('float64')
    y_values = df.loc[mask, col2].astype('float64')

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(x_values, y_values, alpha=0.5)
    ax.set_title(f'Scatter Plot of {col1} vs {col2}')
    ax.set_xlabel(col1)
    ax.set_ylabel(col2)
    st.pyplot(fig)
    # Close the explicit figure so reruns do not accumulate open figures.
    plt.close(fig)
|
|
|
def plot_line(df):
    """Render a line plot of a user-selected numeric column against the row index.

    Args:
        df: DataFrame whose numeric columns are offered in a selectbox.

    Writes a notice instead of a chart when ``df`` has no numeric columns.
    """
    # 'number' covers every int/float width; timedelta64 (a numpy integer
    # subtype) is excluded because it cannot be plotted as a plain number.
    numeric_columns = df.select_dtypes(include='number', exclude='timedelta').columns.tolist()

    if not numeric_columns:
        st.write("No numeric columns available for a line plot.")
        return

    selected_column = st.selectbox('Select column for line plot', numeric_columns, key='line_col')
    # float64 turns nullable-integer <NA> into NaN, which matplotlib draws as
    # a gap in the line; the Series index is used as the x axis.
    series = df[selected_column].astype('float64')

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(series, marker='o', linestyle='-')
    ax.set_title(f'Line Plot of {selected_column}')
    ax.set_xlabel('Index')
    ax.set_ylabel(selected_column)
    st.pyplot(fig)
    # Close the explicit figure so reruns do not accumulate open figures.
    plt.close(fig)
|
|
|
def plot_bar(df):
    """Render a bar chart of a numeric column summed per category.

    Args:
        df: DataFrame providing a categorical column (object or pandas
            ``category`` dtype) and a numeric column, each chosen by the user.

    Bars are sorted from largest to smallest sum. Writes a notice instead of
    a chart when the required column types are missing or nothing remains to
    plot after grouping.
    """
    # 'category' columns are as valid a grouping key as plain object columns.
    categorical_columns = df.select_dtypes(include=['object', 'category']).columns.tolist()
    numeric_columns = df.select_dtypes(include='number', exclude='timedelta').columns.tolist()

    if not (categorical_columns and numeric_columns):
        st.write("No suitable columns available for plotting a bar chart.")
        return

    category_col = st.selectbox('Select the category column', categorical_columns, key='cat_col')
    numeric_col = st.selectbox('Select the numeric column', numeric_columns, key='num_col')

    # observed=True: for categorical dtypes, plot only categories that occur
    # in the data (no effect for object columns).
    data_to_plot = (
        df.groupby(category_col, observed=True)[numeric_col]
        .sum()
        .sort_values(ascending=False)
    )

    # groupby drops missing keys; an all-missing category column leaves
    # nothing to draw, and pandas plotting raises on an empty Series.
    if data_to_plot.empty:
        st.write("No data available for the selected columns.")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    data_to_plot.plot(kind='bar', ax=ax)
    ax.set_title(f'Bar Chart of {numeric_col} by {category_col}')
    ax.set_xlabel(category_col)
    ax.set_ylabel(f'Sum of {numeric_col}')
    st.pyplot(fig)
    # Close the explicit figure so reruns do not accumulate open figures.
    plt.close(fig)
|
|
|
def main():
    """Configure the page, then render the app mode chosen in the sidebar."""
    # Streamlit requires set_page_config to precede other Streamlit commands.
    st.set_page_config(page_title="Falcon 7B CSV Chatbot", layout="wide")
    st.title("Falcon 7B CSV Chatbot")

    st.sidebar.title("Navigation")
    # The chat page loads a Falcon-7B GGML model, so the label names Falcon
    # (it previously advertised Llama-2, which the app never uses).
    pages = {
        "Chat with Falcon-7B": run_llama_chatbot,
        "Data Visualization": data_visualization,
    }
    app_mode = st.sidebar.selectbox("Choose the app mode", list(pages))
    pages[app_mode]()

    st.sidebar.markdown('''The Falcon 7B CSV Chatbot uses the **Falcon-7B-GGML** model.''')
|
|
|
def _load_csv_chunks(file_bytes, temp_dir):
    """Write uploaded CSV bytes to a unique temp file, load them with CSVLoader and split into chunks."""
    # A generated file name (instead of the user-supplied one) avoids path
    # traversal and collisions between concurrent sessions. delete=False lets
    # CSVLoader reopen the file by path on every platform.
    with tempfile.NamedTemporaryFile(dir=temp_dir, suffix=".csv", delete=False) as tmp:
        tmp.write(file_bytes)
        file_path = tmp.name
    try:
        loader = CSVLoader(file_path=file_path, encoding="utf-8", csv_args={'delimiter': ','})
        data = loader.load()
    finally:
        # Remove the temp copy even when loading fails.
        os.remove(file_path)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
    return text_splitter.split_documents(data)


def _get_session_resource(key, factory):
    """Return ``st.session_state[key]``, creating it with ``factory()`` on first use."""
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


def run_llama_chatbot():
    """Render the CSV question-answering chat page.

    The user uploads a CSV in the sidebar. Its rows are loaded with
    ``CSVLoader`` and split into chunks. The chunks are embedded with
    ``sentence-transformers/all-MiniLM-L6-v2`` into a FAISS index, which is
    also saved to ``vectorstore/db_faiss``. Questions are answered by a local
    Falcon-7B GGML model (via ``CTransformers``) through a
    ``ConversationalRetrievalChain``.

    Streamlit re-executes the script on every interaction. The embedding
    model and LLM are therefore created once per session, and the retrieval
    chain is rebuilt only when the uploaded file's content changes. All of
    them are kept in ``st.session_state``.
    """
    DB_FAISS_PATH = "vectorstore/db_faiss"
    TEMP_DIR = "temp"

    # exist_ok avoids a check-then-create race.
    os.makedirs(TEMP_DIR, exist_ok=True)

    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'], help="Upload a CSV file")

    add_vertical_space(1)
    st.sidebar.markdown('Made by Sunirmala Mohanta')

    if uploaded_file is None:
        return

    st.write(f"Uploaded file: {uploaded_file.name}")

    file_bytes = uploaded_file.getvalue()
    # Identify the upload by content so re-uploading a changed file with the
    # same name still triggers a rebuild.
    file_digest = hashlib.sha256(file_bytes).hexdigest()

    if st.session_state.get("csv_chat_digest") != file_digest:
        st.write("Processing CSV file...")
        try:
            text_chunks = _load_csv_chunks(file_bytes, TEMP_DIR)
        except (RuntimeError, UnicodeDecodeError) as err:
            # Depending on the langchain version, CSVLoader raises decode
            # errors directly or wrapped in RuntimeError.
            st.error(f"Could not load the CSV file: {err}")
            return

        if not text_chunks:
            # FAISS cannot build an index from zero documents.
            st.warning("The uploaded CSV file contains no rows to index.")
            return

        embeddings = _get_session_resource(
            "csv_chat_embeddings",
            lambda: HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'),
        )
        docsearch = FAISS.from_documents(text_chunks, embeddings)
        docsearch.save_local(DB_FAISS_PATH)

        llm = _get_session_resource(
            "csv_chat_llm",
            lambda: CTransformers(model="models/falcon-7b-instruct.ggccv1.q4_0.bin",
                                  model_type="falcon",
                                  max_new_tokens=512,
                                  temperature=0.1),
        )

        st.session_state["csv_chat_chain"] = ConversationalRetrievalChain.from_llm(
            llm, retriever=docsearch.as_retriever()
        )
        st.session_state["csv_chat_num_chunks"] = len(text_chunks)
        # Record the digest last so a failure above forces a retry next run.
        st.session_state["csv_chat_digest"] = file_digest

    st.write(f"Total text chunks: {st.session_state['csv_chat_num_chunks']}")
    qa = st.session_state["csv_chat_chain"]

    st.write("### Enter your query:")
    query = st.text_input("Input Prompt:")
    if query:
        with st.spinner("Processing your question..."):
            # Each question is answered independently: no prior turns are sent.
            result = qa({"question": query, "chat_history": []})
        st.write("---")
        st.write("### Response:")
        st.write(f"> {result['answer']}")
|
|
|
|
|
def data_visualization():
    """Render the data-visualization page for a CSV uploaded in the sidebar.

    Shows a preview of the first rows and draws the chart type selected in
    the sidebar radio. Parse failures are reported on the page instead of
    crashing the app.
    """
    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'])
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded CSV file: {err}")
        return

    st.write("Uploaded file preview:")
    st.write(df.head())

    # Insertion order defines the radio option order.
    plotters = {
        'Histogram': plot_histogram,
        'Scatter Plot': plot_scatter,
        'Line Plot': plot_line,
        'Bar Chart': plot_bar,
    }
    plot_type = st.sidebar.radio("Choose a type of plot:", tuple(plotters))
    plotters[plot_type](df)
|
|
|
if __name__ == "__main__":
    # Render the app only when this file is executed as a script, not when imported.
    main()