ia_back / rag.py
Ilyas KHIAT
enhancement
d70f173
raw
history blame
6.81 kB
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from uuid import uuid4
from prompt import *
import random
from itext2kg.models import KnowledgeGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import os
from langchain_core.tools import tool
import pickle
import unicodedata
# Populate os.environ from a local .env file (e.g. INDEX_NAME) before any
# client below is created.
load_dotenv()
index_name = os.getenv("INDEX_NAME")

# Shared embedding client, created once at import time.
embedding_model = "text-embedding-3-small"
embedding = OpenAIEmbeddings(model=embedding_model)
# NOTE: a Pinecone-backed vector store used to be configured here; this module
# now builds an in-memory FAISS store instead (see get_vectorstore).
def advanced_graph_to_json(graph:KnowledgeGraph):
nodes = []
edges = []
for node in graph.entities:
node_id = node.name.replace(" ", "_")
label = node.name
type = node.label
nodes.append({"id": node_id, "label": label, "type": type})
for relationship in graph.relationships:
source = relationship.startEntity
source_id = source.name.replace(" ", "_")
target = relationship.endEntity
target_id = target.name.replace(" ", "_")
label = relationship.name
edges.append({"source": source_id, "label": label, "cible": target_id})
return {"noeuds": nodes, "relations": edges}
def _load_pickle(path):
    """Deserialise and return the object stored in the pickle file at *path*."""
    # SECURITY NOTE: pickle.load can execute arbitrary code; these files must
    # come from a trusted source and never from user input.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Knowledge graph: the raw object plus its JSON form (passed to the chat prompt).
loaded_graph = _load_pickle("kg_ia_signature.pkl")
graph = advanced_graph_to_json(loaded_graph)
print("Graph loaded")

# Pre-computed text chunks, used when no specific scene is requested.
chunks = _load_pickle("chunks_ia_signature.pkl")
print("Chunks loaded")

# Scene texts; get_random_chunk indexes them with 1-based scene numbers.
scenes = _load_pickle("scenes.pkl")
print("Scenes loaded")
# Structured-output schema for generate_sphinx_response: the LLM returns a
# question about the book together with its accepted answers.
# NOTE(review): deliberately no class docstring — pydantic turns a class
# docstring into the JSON-schema description, which with_structured_output
# would forward to the model and thereby change the prompt.
class sphinx_output(BaseModel):
    # Question put to the user.
    question: str = Field(description="The question to ask the user to test if they read the entire book")
    # Answers considered correct for that question.
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")
# Payload describing a user's answer to a sphinx question. Not referenced by
# the code visible in this file — presumably the request body of an endpoint
# that calls verify_response (TODO confirm). Its fields mirror
# verify_response's arguments (response, answers, question).
# NOTE(review): no class docstring on purpose, to keep the pydantic schema
# description unchanged.
class verify_response_model(BaseModel):
    # The user's answer.
    response: str = Field(description="The response from the user to the question")
    # Answers considered correct.
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")
    # The question that was originally asked.
    initial_question: str = Field(description="The question asked to the user to test if they read the entire book")
# Structured-output schema used by verify_response: the LLM grades the user's
# answer and verify_response compares `score` against its pass mark.
# NOTE(review): no class docstring on purpose — pydantic would expose it as
# the schema description sent to the model.
class verification_score(BaseModel):
    # Grade on a 0-10 scale as requested in the description; the range is
    # not enforced by a validator.
    score: float = Field(description="The score of the user's response from 0 to 10 to the question")
# Shared chat model used for sphinx question generation, answer scoring and
# generate_stream.
llm = ChatOpenAI(
    temperature=0.5,
    max_tokens=300,
    model="gpt-4o",
)
def split_texts(text: str) -> list[str]:
    """Split *text* into overlapping chunks for sampling/retrieval.

    Chunks target at most 1000 characters (measured with ``len``) and
    consecutive chunks overlap by up to 200 characters. The splitter's default
    separators are used and treated as literal strings, not regexes.
    """
    splitter_config = dict(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False,
    )
    return RecursiveCharacterTextSplitter(**splitter_config).split_text(text)
#########################################################################
### CHOOSE A SPECIFIC SCENE HERE, VIA THE FUNCTION'S ARGUMENT
# Sentinel meaning "argument not supplied": the default scene list is then
# built fresh on every call. A shared mutable default would leak, because the
# function hands the selection back to its caller.
_DEFAULT_SCENES = object()


def get_random_chunk(scene_specific=_DEFAULT_SCENES):
    """Pick a random text excerpt, optionally restricted to given scenes.

    Args:
        scene_specific: 1-based scene numbers (indexes into the module-level
            ``scenes``) whose text is concatenated and re-split before
            sampling. Defaults to ``[4]``. A falsy value (``None`` or an empty
            list) means the whole story: the excerpt is drawn from the
            pre-computed module-level ``chunks``.

    Returns:
        ``(excerpt, scene_specific)``: the chosen chunk and the scene
        selection that was used (falsy when the whole story was used).

    Raises:
        IndexError: if a scene number is outside ``1..len(scenes)``.
        ValueError: if there is no text to sample from.
    """
    if scene_specific is _DEFAULT_SCENES:
        scene_specific = [4]

    # Whole-story mode.
    if not scene_specific:
        if not chunks:
            raise ValueError("No chunks available to pick from")
        return random.choice(chunks), scene_specific

    # Scene numbers are 1-based: reject 0/negatives instead of letting them
    # silently wrap around to the end of the list.
    for number in scene_specific:
        if not 1 <= number <= len(scenes):
            raise IndexError(
                f"Scene number {number} out of range (1..{len(scenes)})"
            )

    scene_text = " ".join(scenes[number - 1] for number in scene_specific)
    chunks_scene = split_texts(scene_text)
    print(f"Scene {scene_specific} has {len(chunks_scene)} chunks")
    print([chunk[0:50] for chunk in chunks_scene])
    print('---')
    if not chunks_scene:
        raise ValueError(f"Scene {scene_specific} contains no text to sample")
    chunk_chosen = random.choice(chunks_scene)
    print(f"Chosen chunk: {chunk_chosen}")
    return chunk_chosen, scene_specific
def get_vectorstore() -> FAISS:
    """Build an in-memory FAISS vector store over all pre-loaded ``chunks``.

    The embedding dimension is discovered by embedding a probe string, and the
    index is an exact L2 (``IndexFlatL2``) index. Every chunk is stored as a
    Document under a freshly generated random UUID. This calls the embeddings
    API once for the probe and again when the documents are added.
    """
    dimension = len(embedding.embed_query("hello world"))
    store = FAISS(
        embedding_function=embedding,
        index=faiss.IndexFlatL2(dimension),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )
    docs = []
    doc_ids = []
    for text in chunks:
        docs.append(Document(page_content=text))
        doc_ids.append(str(uuid4()))
    store.add_documents(documents=docs, ids=doc_ids)
    return store
# Module-wide FAISS store, built once at import time over all `chunks`.
# NOTE: the name is misspelled ("vectore") but retrieve_context_from_vectorestore
# (and possibly other modules) reference it, so it is kept as-is.
vectore_store: FAISS = get_vectorstore()
def generate_sphinx_response() -> sphinx_output:
    """Ask the LLM for a comprehension question about a random book excerpt.

    The excerpt comes from get_random_chunk(). When it was drawn from specific
    scenes, the prompt's summary is replaced by a "scene <selection>" label.
    Otherwise the book-wide ``summary_text`` is used (not defined in this
    file; presumably star-imported from ``prompt``).

    Returns:
        sphinx_output holding the generated question and its accepted answers.
    """
    # Read before sampling, to keep the original evaluation order.
    default_summary = summary_text
    excerpt, scene_number = get_random_chunk()
    summary = f"scene {scene_number}" if scene_number else default_summary

    chain = PromptTemplate.from_template(template_sphinx) | llm.with_structured_output(sphinx_output)
    return chain.invoke(
        {
            "writer": "Laurent Tripied",
            "book_name": "Limites de l'imaginaire ou limites planétaires",
            "summary": summary,
            "excerpt": excerpt,
        }
    )
#############################################################
### CHOOSE HOW STRICT THE VERIFICATION IS HERE (via `threshold`)
def verify_response(response: str, answers: list[str], question: str, *, threshold: float = 5) -> bool:
    """Have the LLM grade the user's answer and compare it with a pass mark.

    Args:
        response: the user's answer.
        answers: answers considered correct for *question*.
        question: the question that was asked.
        threshold: minimum score (the model is asked for a 0-10 grade) needed
            to pass; raise it for a stricter check. Defaults to 5.

    Returns:
        True if the LLM's score is >= *threshold*, otherwise False.
    """
    prompt = PromptTemplate.from_template(template_verify)
    structured_llm = llm.with_structured_output(verification_score)
    llm_chain = prompt | structured_llm
    score = llm_chain.invoke({"response": response, "answers": answers, "initial_question": question})
    # Return an explicit bool in both cases (previously a failing score fell
    # through to an implicit None despite the `-> bool` annotation).
    return score.score >= threshold
def retrieve_context_from_vectorestore(query: str, *, k: int = 3) -> list[Document]:
    """Return the chunks most relevant to *query* from the module-level store.

    Uses maximal-marginal-relevance (``search_type="mmr"``) retrieval over
    ``vectore_store``.

    Args:
        query: free-text user query.
        k: number of documents to return (default 3).

    Returns:
        The retrieved Documents. This is a list, not a str; generate_stream
        passes it straight into its prompt template, which stringifies it.
    """
    retriever = vectore_store.as_retriever(search_type="mmr", search_kwargs={"k": k})
    return retriever.invoke(query)
def generate_stream(query: str, messages=None, model="gpt-4o-mini", max_tokens=300, temperature=1, index_name="", stream=True, vector_store=None):
    """Answer *query* about the book using retrieved context and the knowledge graph.

    Retrieves the chunks most relevant to *query*, fills ``template`` (not
    defined in this file; presumably star-imported from ``prompt``) with the
    book metadata, the JSON knowledge graph and that context, and runs it
    through the shared module-level ``llm``.

    Args:
        query: the user's question.
        messages, model, max_tokens, temperature, index_name, vector_store:
            currently IGNORED and kept only for call-site compatibility. Model
            settings come from the module-level ``llm``, and retrieval always
            uses the module-level FAISS store.
        stream: if True, return an iterator of text chunks; otherwise return
            the complete answer string.

    Returns:
        An iterator of str chunks (stream=True), a str (stream=False), or
        False if building/running the request raised an exception.

    Note:
        With stream=True the chain runs lazily as the caller iterates, so
        errors raised during iteration are NOT caught by this function.
    """
    try:
        print("init chat")
        print("init template")
        prompt = PromptTemplate.from_template(template)
        print("retreiving context")
        context = retrieve_context_from_vectorestore(query)
        print(f"Context: {context}")
        llm_chain = prompt | llm | StrOutputParser()
        inputs = {
            "name_book": "Limites de l'imaginaire ou limites planétaires",
            "writer": "Laurent Tripied",
            "name_icon": "Magritte",
            "kg": graph,
            "context": context,
            "query": query,
        }
        print("streaming")
        if stream:
            return llm_chain.stream(inputs)
        return llm_chain.invoke(inputs)
    except Exception as e:
        # Best-effort boundary: report the error and signal failure with False.
        print(e)
        return False