Spaces:
Running
Running
File size: 5,003 Bytes
b4c2b4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 22 15:43:16 2024
@author: Raphaël d'Assignies (rdassignies@protonmail.ch)
"""
import json
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from nodes import (GradeResults, GraphState, generate_query_node,
generate_results_node, query_feedback_node,
evaluate_query_node, evaluate_results_node)
import streamlit as st
# Instanciate pipeline
pipeline = StateGraph(GraphState)
pipeline.add_node('generate_query', generate_query_node)
pipeline.add_node('generate_results', generate_results_node)
pipeline.add_node('query_feedback', query_feedback_node)
# Only query
#pipeline.add_edge(START,'generate_query')
#pipeline.add_edge('generate_query', generate_query_node)
#pipeline.add_edge('generate_query', END)
# Full scenario
pipeline.add_edge(START,'generate_query')
pipeline.add_conditional_edges(
'generate_query',
evaluate_query_node,
{'error_query' : 'generate_query',
'ok' : 'generate_results'
})
pipeline.add_conditional_edges(
'generate_results',
evaluate_results_node,
{
"yes": END,
"no": 'query_feedback',
"max_generation_reached": END
}
)
# Création du graph
graph = pipeline.compile()
# Load le dataframe
df = pd.read_json('bodacc.json', orient='table')
# Initialise le dictionnaire
inputs = {
'df_head': df.head().to_csv(),
'df': df
}
# Créé un dictionnaire des sorties vide
outputs = {}
# Titre de l'application
st.title("Chat with BODACC !")
# Message d'avertissement
warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} "
f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.")
st.warning(warning_message)
# Interface utilisateur pour entrer la requête
user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours")
# Afficher les résultats avec Streamlit
inputs["instructions"] = user_query
# Afficher un bouton pour démarrer la recherche
if st.button("Lancer la recherche"):
config = {"configurable": {"thread_id": "2"}}
# Étape 1 : Afficher le message "Je réfléchis..."
st.write("Je réfléchis...")
# Stream des résultats au fur et à mesure
with st.spinner('Recherche en cours...'):
for output in graph.stream(inputs, stream_mode='values', debug=False):
# Ajouter les résultats au dictionnaire outputs
for k, v in output.items():
if k not in outputs:
outputs[k] = []
outputs[k].append(v)
if "results" in output and len(output["results"]) > 0:
records = json.loads(output['results'])
st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.")
# Après la fin du traitement
if "results" in outputs and len(outputs["results"]) > 0:
# Agréger tous les résultats accumulés
all_results = []
for res in outputs["results"]:
json_data = json.loads(res) # Convertir chaque ensemble de résultats en JSON
all_results.extend(json_data) # Accumuler tous les résultats
results_df = pd.DataFrame(all_results) # Créer un DataFrame avec tous les résultats accumulés
# Afficher un aperçu des résultats (jusqu'à 5 premiers)
num_results = len(results_df)
st.write(f"J'ai trouvé {num_results} résultats.")
if num_results > 0:
preview_count = min(5, num_results) # Gérer le cas où il y a moins de 5 résultats
st.write(f"Voici un aperçu des {preview_count} premiers résultats :")
st.write(results_df.head(preview_count))
trunc = outputs.get('truncated', 'pas de traunc')
if trunc[0] == True:
st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ")
# Convertir tous les résultats en CSV
csv = results_df.to_csv(index=False)
# Ajouter un bouton pour télécharger tous les résultats
st.download_button(
label="Télécharger le résultat complet au format CSV",
data=csv,
file_name="results.csv",
mime="text/csv"
)
else:
# Si aucun résultat n'est trouvé
st.write("Aucun résultat trouvé.")
|