Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Sun Sep 22 15:43:16 2024 | |
@author: Raphaël d'Assignies (rdassignies@protonmail.ch) | |
""" | |
import json | |
from typing import Literal, Optional, List, Union, Any | |
from langchain_openai import ChatOpenAI | |
import pandas as pd | |
from langchain_core.prompts import ChatPromptTemplate | |
from langgraph.graph import END, StateGraph, START | |
from langchain_core.output_parsers import StrOutputParser | |
from pydantic import BaseModel, Field | |
from models import NatureJugement | |
from nodes import (GradeResults, GraphState, generate_query_node, | |
generate_results_node, query_feedback_node, | |
evaluate_query_node, evaluate_results_node) | |
import streamlit as st | |
# Instanciate pipeline | |
pipeline = StateGraph(GraphState) | |
pipeline.add_node('generate_query', generate_query_node) | |
pipeline.add_node('generate_results', generate_results_node) | |
pipeline.add_node('query_feedback', query_feedback_node) | |
# Only query | |
#pipeline.add_edge(START,'generate_query') | |
#pipeline.add_edge('generate_query', generate_query_node) | |
#pipeline.add_edge('generate_query', END) | |
# Full scenario | |
pipeline.add_edge(START,'generate_query') | |
pipeline.add_conditional_edges( | |
'generate_query', | |
evaluate_query_node, | |
{'error_query' : 'generate_query', | |
'ok' : 'generate_results' | |
}) | |
pipeline.add_conditional_edges( | |
'generate_results', | |
evaluate_results_node, | |
{ | |
"yes": END, | |
"no": 'query_feedback', | |
"max_generation_reached": END | |
} | |
) | |
# Création du graph | |
graph = pipeline.compile() | |
# Load le dataframe | |
df = pd.read_json('bodacc.json', orient='table') | |
# Initialise le dictionnaire | |
inputs = { | |
'df_head': df.head().to_csv(), | |
'df': df | |
} | |
# Créé un dictionnaire des sorties vide | |
outputs = {} | |
# Titre de l'application | |
st.title("Chat with BODACC !") | |
# Message d'avertissement | |
warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} " | |
f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.") | |
st.warning(warning_message) | |
# Interface utilisateur pour entrer la requête | |
user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours") | |
# Afficher les résultats avec Streamlit | |
inputs["instructions"] = user_query | |
# Afficher un bouton pour démarrer la recherche | |
if st.button("Lancer la recherche"): | |
config = {"configurable": {"thread_id": "2"}} | |
# Étape 1 : Afficher le message "Je réfléchis..." | |
st.write("Je réfléchis...") | |
# Stream des résultats au fur et à mesure | |
with st.spinner('Recherche en cours...'): | |
for output in graph.stream(inputs, stream_mode='values', debug=False): | |
# Ajouter les résultats au dictionnaire outputs | |
for k, v in output.items(): | |
if k not in outputs: | |
outputs[k] = [] | |
outputs[k].append(v) | |
if "results" in output and len(output["results"]) > 0: | |
records = json.loads(output['results']) | |
st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.") | |
# Après la fin du traitement | |
if "results" in outputs and len(outputs["results"]) > 0: | |
# Agréger tous les résultats accumulés | |
all_results = [] | |
for res in outputs["results"]: | |
json_data = json.loads(res) # Convertir chaque ensemble de résultats en JSON | |
all_results.extend(json_data) # Accumuler tous les résultats | |
results_df = pd.DataFrame(all_results) # Créer un DataFrame avec tous les résultats accumulés | |
# Afficher un aperçu des résultats (jusqu'à 5 premiers) | |
num_results = len(results_df) | |
st.write(f"J'ai trouvé {num_results} résultats.") | |
if num_results > 0: | |
preview_count = min(5, num_results) # Gérer le cas où il y a moins de 5 résultats | |
st.write(f"Voici un aperçu des {preview_count} premiers résultats :") | |
st.write(results_df.head(preview_count)) | |
trunc = outputs.get('truncated', 'pas de traunc') | |
if trunc[0] == True: | |
st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ") | |
# Convertir tous les résultats en CSV | |
csv = results_df.to_csv(index=False) | |
# Ajouter un bouton pour télécharger tous les résultats | |
st.download_button( | |
label="Télécharger le résultat complet au format CSV", | |
data=csv, | |
file_name="results.csv", | |
mime="text/csv" | |
) | |
else: | |
# Si aucun résultat n'est trouvé | |
st.write("Aucun résultat trouvé.") | |