#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sun Sep 22 15:43:16 2024 @author: Raphaël d'Assignies (rdassignies@protonmail.ch) """ import json from typing import Literal, Optional, List, Union, Any from langchain_openai import ChatOpenAI import pandas as pd from langchain_core.prompts import ChatPromptTemplate from langgraph.graph import END, StateGraph, START from langchain_core.output_parsers import StrOutputParser from pydantic import BaseModel, Field from models import NatureJugement from nodes import (GradeResults, GraphState, generate_query_node, generate_results_node, query_feedback_node, evaluate_query_node, evaluate_results_node) import streamlit as st # Instanciate pipeline pipeline = StateGraph(GraphState) pipeline.add_node('generate_query', generate_query_node) pipeline.add_node('generate_results', generate_results_node) pipeline.add_node('query_feedback', query_feedback_node) # Only query #pipeline.add_edge(START,'generate_query') #pipeline.add_edge('generate_query', generate_query_node) #pipeline.add_edge('generate_query', END) # Full scenario pipeline.add_edge(START,'generate_query') pipeline.add_conditional_edges( 'generate_query', evaluate_query_node, {'error_query' : 'generate_query', 'ok' : 'generate_results' }) pipeline.add_conditional_edges( 'generate_results', evaluate_results_node, { "yes": END, "no": 'query_feedback', "max_generation_reached": END } ) # Création du graph graph = pipeline.compile() # Load le dataframe df = pd.read_json('bodacc.json', orient='table') # Initialise le dictionnaire inputs = { 'df_head': df.head().to_csv(), 'df': df } # Créé un dictionnaire des sorties vide outputs = {} # Titre de l'application st.title("Chat with BODACC !") # Message d'avertissement warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} " f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.") st.warning(warning_message) # Interface utilisateur pour entrer la requête user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours") # Afficher les résultats avec Streamlit inputs["instructions"] = user_query # Afficher un bouton pour démarrer la recherche if st.button("Lancer la recherche"): config = {"configurable": {"thread_id": "2"}} # Étape 1 : Afficher le message "Je réfléchis..." st.write("Je réfléchis...") # Stream des résultats au fur et à mesure with st.spinner('Recherche en cours...'): for output in graph.stream(inputs, stream_mode='values', debug=False): # Ajouter les résultats au dictionnaire outputs for k, v in output.items(): if k not in outputs: outputs[k] = [] outputs[k].append(v) # Ne pas afficher les messages pour les clés non pertinentes (comme error_query) if 'query' in output and len(output['query'])>0: st.write(f"query : {output['query']}") #st.write(outputs.get('query_feedbacks', 'pas de feedback')) #st.write(outputs.get('results_feedbacks', 'pas de resultfeedback')) if "results" in output and len(output["results"]) > 0: records = json.loads(output['results']) st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.") # Après la fin du traitement if "results" in outputs and len(outputs["results"]) > 0: # Agréger tous les résultats accumulés all_results = [] for res in outputs["results"]: json_data = json.loads(res) # Convertir chaque ensemble de résultats en JSON all_results.extend(json_data) # Accumuler tous les résultats results_df = pd.DataFrame(all_results) # Créer un DataFrame avec tous les résultats accumulés # Afficher un aperçu des résultats (jusqu'à 5 premiers) num_results = len(results_df) st.write(f"J'ai trouvé {num_results} résultats.") if num_results > 0: preview_count = min(5, num_results) # Gérer le cas où il y a moins de 5 résultats st.write(f"Voici un aperçu des {preview_count} premiers résultats :") st.write(results_df.head(preview_count)) trunc = outputs.get('truncated', 'pas de traunc') if trunc[0] == True: st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ") # Convertir tous les résultats en CSV csv = results_df.to_csv(index=False) # Ajouter un bouton pour télécharger tous les résultats st.download_button( label="Télécharger le résultat complet au format CSV", data=csv, file_name="results.csv", mime="text/csv" ) else: # Si aucun résultat n'est trouvé st.write("Aucun résultat trouvé.")