#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sun Oct 13 10:30:56 2024 @author: legalchain """ from typing import Literal, Optional, List, Union, Any from langchain_openai import ChatOpenAI import pandas as pd from langchain_core.prompts import ChatPromptTemplate from langgraph.graph import END, StateGraph, START from langchain_core.output_parsers import StrOutputParser from pydantic import BaseModel, Field from models import NatureJugement from prompts import df_prompt, feed_back_prompt, reflection_prompt llm = ChatOpenAI(model="gpt-4o-mini") MAX_GENERATIONS = 2 MAX_ROWS: int = 10 class Query(BaseModel): query:str = Field(..., title="Requête pour filtrer les résultats du dataframe entourée avec des gullemets de type \" ") def clean_query(self): # Correction des échappements dans la chaîne de la requête corrected_query = self.query.replace("\\'", "\\'") # Extraire la condition à l'intérieur des crochets import re condition = re.search(r"df\[(.*)\]", corrected_query).group(1) return condition class GradeResults(BaseModel): binary_score: Literal["yes", "no"] = Field( description="Les résultats sont satisfaisants -> 'yes' ou il y une erreur ou pas de résultats ou les résultats sont améliorables -> 'no'" ) class GraphState(BaseModel): df : Any df_head:str instructions: Optional[str] = None nature_jugement: List = ', '.join([e.value for e in NatureJugement]) region:str = '' dep:str = '' query: Optional[str] = None results :Union[str, List[str]] = [] query_feedbacks: Optional[str] = None results_feedbacks: bool = None generation_num: int = 0 retrieval_num: int = 0 search_mode: Literal["vectorstore", "websearch", "QA_LM"] = "QA_LM" error_query: Optional[Any] = "" error_results: Optional[Any] = "" truncated: bool = False # Méthode pour récupérer le DataFrame def get_df(self) -> pd.DataFrame: return pd.read_json(self.df) # Surcharger l'initialisation pour créer les champs 'region' et 'dep' def __init__(self, **data): super().__init__(**data) # Générer les chaînes pour les régions et départements distinct_regions = self.df['region_nom_officiel'].dropna().unique().tolist() distinct_departements = self.df['departement_nom_officiel'].dropna().unique().tolist() # Convertir en chaînes séparées par des virgules self.region = ', '.join(distinct_regions) self.dep = ', '.join(distinct_departements) def generate_query_node(state: GraphState): prompt = ChatPromptTemplate.from_messages(messages = df_prompt) generate_df_query = prompt | llm.with_structured_output( Query, include_raw=True, # permet de checker les erreurs en sortie ) # TODO : Ajouter le retour erreur de parse_error try : query_generate = generate_df_query.invoke({ 'df_head' : state.df_head, 'instructions' : state.instructions, 'feedback' : state.query_feedbacks, 'error' : state.error_query, 'nature_jugement' : state.nature_jugement, 'dep' : state.dep, 'region': state.region }) query_final = query_generate['parsed'].clean_query() return { "query": query_final, "error_query" : "" # si il ya une erreur cela remet le compteur à zéro } except Exception as e: return {'error_query' : e} def evaluate_query_node(state:GraphState): if state.error_query != "": return "Il y a une erreur dans la requête. Je me suis sûrement trompé. Veuillez réessayer." else: return "ok" def generate_results_node(state:GraphState): try : query = state.query print("query ", query) print('je suis dans generate', type(state.df)) query = eval(query, {"df": state.df}) new_df = state.df[query] print("new_df", new_df.empty) if new_df.empty: return { "generation_num": state.generation_num + 1} elif len(new_df)> MAX_ROWS: return {'results' : new_df.head(MAX_ROWS).to_json(orient='records'), "generation_num": state.generation_num + 1, "truncated": True } else: return {'results' : new_df.to_json(orient='records'), "generation_num": state.generation_num + 1, } except Exception as e : return {'error_results' : e, "generation_num": state.generation_num + 1} def evaluate_results_node(state:GraphState): prompt_eval = ChatPromptTemplate.from_messages(messages=reflection_prompt) generate_eval = prompt_eval | llm.with_structured_output( GradeResults, include_raw=False, # permet de checker les erreurs en sortie ) evaluation = generate_eval.invoke({'df_head' : state.df_head, 'results' :state.results, 'instructions' : state.instructions}) if state.generation_num > MAX_GENERATIONS: return "max_generation_reached" return evaluation.binary_score def query_feedback_node(state: GraphState): prompt_feed_back = ChatPromptTemplate.from_messages(messages=feed_back_prompt) query_feedback_chain = prompt_feed_back| llm |StrOutputParser() feedback = query_feedback_chain.invoke({ "df_head" : state.df_head, "instructions": state.instructions, "results": state.results, "query": state.query }) feedback = f"Evaluation de la recherche : {feedback}" print(feedback) return {"query_feedbacks": feedback}