File size: 5,003 Bytes
b4c2b4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 22 15:43:16 2024

@author: Raphaël d'Assignies (rdassignies@protonmail.ch)
"""
import json
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from nodes import (GradeResults, GraphState, generate_query_node, 
                   generate_results_node, query_feedback_node, 
                   evaluate_query_node, evaluate_results_node)
import streamlit as st




# Instanciate pipeline
pipeline = StateGraph(GraphState)

pipeline.add_node('generate_query', generate_query_node)
pipeline.add_node('generate_results', generate_results_node)
pipeline.add_node('query_feedback', query_feedback_node)

# Only query
#pipeline.add_edge(START,'generate_query')
#pipeline.add_edge('generate_query', generate_query_node)
#pipeline.add_edge('generate_query', END)

# Full scenario
pipeline.add_edge(START,'generate_query')
pipeline.add_conditional_edges(
    'generate_query', 
    evaluate_query_node, 
    {'error_query' : 'generate_query',
     'ok' : 'generate_results'
     })

pipeline.add_conditional_edges(
    'generate_results', 
    evaluate_results_node,
    {
        "yes": END,
        "no": 'query_feedback',
        "max_generation_reached": END

    }  
)


# Création du graph
graph = pipeline.compile()

# Load le dataframe
df = pd.read_json('bodacc.json', orient='table')

# Initialise le dictionnaire
inputs = {
        'df_head': df.head().to_csv(), 
        'df': df
    }

# Créé un dictionnaire des sorties vide
outputs = {}


# Titre de l'application
st.title("Chat with BODACC !")

# Message d'avertissement
warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} "
                   f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.")

st.warning(warning_message)
# Interface utilisateur pour entrer la requête
user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours")


# Afficher les résultats avec Streamlit
inputs["instructions"] = user_query


# Afficher un bouton pour démarrer la recherche
if st.button("Lancer la recherche"):
    config = {"configurable": {"thread_id": "2"}}
    
    # Étape 1 : Afficher le message "Je réfléchis..."
    st.write("Je réfléchis...")

    # Stream des résultats au fur et à mesure
    with st.spinner('Recherche en cours...'):
        for output in graph.stream(inputs, stream_mode='values', debug=False):
            # Ajouter les résultats au dictionnaire outputs
            for k, v in output.items():
                if k not in outputs:
                    outputs[k] = []
                outputs[k].append(v)

            if "results" in output and len(output["results"]) > 0:
                records = json.loads(output['results'])
                st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.")

    # Après la fin du traitement
    if "results" in outputs and len(outputs["results"]) > 0:
        # Agréger tous les résultats accumulés
        all_results = []
        for res in outputs["results"]:
            json_data = json.loads(res)  # Convertir chaque ensemble de résultats en JSON
            all_results.extend(json_data)  # Accumuler tous les résultats

        results_df = pd.DataFrame(all_results)  # Créer un DataFrame avec tous les résultats accumulés
        # Afficher un aperçu des résultats (jusqu'à 5 premiers)
        num_results = len(results_df)
        st.write(f"J'ai trouvé {num_results} résultats.")
        if num_results > 0:
            preview_count = min(5, num_results)  # Gérer le cas où il y a moins de 5 résultats
            st.write(f"Voici un aperçu des {preview_count} premiers résultats :")
            st.write(results_df.head(preview_count))
            
            trunc = outputs.get('truncated', 'pas de traunc')
            
            if trunc[0] == True:
                st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ")
    
        # Convertir tous les résultats en CSV
        csv = results_df.to_csv(index=False)

        # Ajouter un bouton pour télécharger tous les résultats
        st.download_button(
            label="Télécharger le résultat complet au format CSV",
            data=csv,
            file_name="results.csv",
            mime="text/csv"
        )

    else:
        # Si aucun résultat n'est trouvé
        st.write("Aucun résultat trouvé.")