# GenProp/src/control/controller.py
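"""
Controller for GenProp: applies a template's styles to uploaded .docx documents,
generates paragraphs from the semantic database (wiki pages and private collections),
and produces responses to the requirements listed in an Excel file.
"""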
import asyncio
import os
import shutil
import json
from typing import Dict, List
import random
import datetime
import string
import docx
import pandas as pd
from src.domain.block import Block
from src.tools.doc_tools import get_title
from src.domain.doc import Doc
from src.domain.wikidoc import WikiPage
from src.view.log_msg import create_msg_from
import src.tools.semantic_db as semantic_db
from src.tools.wiki import Wiki
from src.llm.llm_tools import generate_response_to_exigence
from src.llm.llm_tools import get_wikilist, get_public_paragraph, get_private_paragraph
from src.tools.semantic_db import add_texts_to_collection, query_collection
from src.tools.excel_tools import excel_to_dict
import gradio as gr
from src.retriever.retriever import Retriever
class Controller:
def __init__(self, config: Dict, client_db, retriever):
self.templates_path = config['templates_path']
self.generated_docs_path = config['generated_docs_path']
self.styled_docs_path = config['styled_docs_path']
self.excel_doc_path = config['excel_doc_path']
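        # runtime state: styled copies (new_docs), generated documents (gen_docs) and the requirements file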
self.new_docs = []
self.gen_docs = []
self.input_csv = ""
template_path = config['templates_path'] + '/' + config['templates'][config['default_template_index']]
self.default_template = Doc(template_path)
self.template = self.default_template
self.log = []
self.differences = []
self.list_differences = []
self.client_db = client_db
self.retriever = retriever
    def copy_docs(self, temp_docs: List):
        """
        Initial copy of the incoming documents,
        creation of the collection for requirements retrieval,
        and initialisation of the output paths.
        TODO: Rename or refactor the function -> 1 mission / function
        TODO: To be tested on several documents
        TODO: Rename create_collection in create_requirement_collection
        """
        doc_names = [doc.name for doc in temp_docs]
        for i, doc_name in enumerate(doc_names):
            # keep only the file name, whatever the path separator
            if '/' in doc_name:
                doc_name = doc_name.split('/')[-1]
            elif '\\' in doc_name:
                doc_name = doc_name.split('\\')[-1]
            # drop the extension
            doc_names[i] = doc_name.split('.')[0]
docs = [Doc(path=doc.name) for doc in temp_docs]
self.create_collection(docs)
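        # each uploaded document is copied to generated_docs_path with a '_' suffix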
        style_paths = [f"{self.generated_docs_path}/{dn}_.docx" for dn in doc_names]
        for doc, style_path in zip(docs, style_paths):
            new_doc = doc.copy(style_path)
            self.new_docs.append(new_doc)
def clear_docs(self):
for new_doc in self.new_docs:
if os.path.exists(new_doc.path):
new_doc.clear()
for gen_doc in self.gen_docs:
if os.path.exists(gen_doc.path):
gen_doc.clear()
self.new_docs = []
self.gen_docs = []
self.log = []
path_to_clear = os.path.abspath(self.generated_docs_path)
second_path_to_clear = os.path.abspath(self.excel_doc_path)
[os.remove(f"{path_to_clear}/{doc}") for doc in os.listdir(path_to_clear)]
[os.remove(f"{second_path_to_clear}/{doc}") for doc in os.listdir(second_path_to_clear)]
def set_template(self, template_name: str = ""):
if not template_name:
self.template = self.default_template
else:
template_path = f"{self.templates_path}/{template_name}"
self.template = Doc(template_path)
def add_template(self, template_path: str):
"""
        TODO: message to be put in config
"""
if not template_path:
return
elif not template_path.name.endswith(".docx"):
gr.Warning("Seuls les fichiers .docx sont acceptés")
return
doc = docx.Document(template_path.name)
doc.save(self.templates_path + '/' + get_title(template_path.name))
def delete_curr_template(self, template_name: str):
if not template_name:
return
os.remove(f"{self.templates_path}/{template_name}")
def retrieve_number_of_misapplied_styles(self):
"""
not used: buggy !!
"""
res = {}
for new_doc in self.new_docs:
res[new_doc] = new_doc.retrieve_number_of_misapplied_styles()
return res
def get_difference_with_template(self):
self.differences = []
for new_doc in self.new_docs:
diff_styles = new_doc.get_different_styles_with_template(template=self.template)
diff_dicts = [{'doc': new_doc, 'style': s} for s in diff_styles]
self.differences += diff_dicts
template_styles = self.template.xdoc.styles
template_styles = [style for style in template_styles if style.name in self.template.styles.names]
return self.differences, template_styles
def get_list_styles(self):
self.list_differences = []
for new_doc in self.new_docs:
list_styles = new_doc.get_list_styles()
all_lists_styles = [{'doc': new_doc, 'list_style': s} for s in list_styles]
self.list_differences += all_lists_styles
return self.list_differences
def map_style(self, this_style_index: int, template_style_name: str):
"""
maps a style from 'this' document into a style from the template
"""
        # don't make any change if the style is already the same
diff_dict = self.differences[this_style_index]
doc = diff_dict['doc']
this_style_name = diff_dict['style']
log = doc.copy_one_style(this_style_name, template_style_name, self.template)
if log:
self.log.append({doc.name: log})
def update_list_style(self, this_style_index: int, template_style_name: str):
"""
maps a style from 'this' document into a style from the template
"""
#dont make any change if the style is already the same
diff_dict = self.list_differences[this_style_index]
doc = diff_dict['doc']
this_style_name = diff_dict['list_style']
log = doc.change_bullet_style(this_style_name, template_style_name, self.template)
if log:
self.log.append({doc.name: log})
    def update_style(self, index, style_to_modify):
return self.map_style(index, style_to_modify) if style_to_modify else None
def apply_template(self, options_list):
for new_doc in self.new_docs:
log = new_doc.apply_template(template=self.template, options_list=options_list)
if log:
self.log.append({new_doc.name: log})
def reset(self):
for new_doc in self.new_docs:
new_doc.delete()
for gen_doc in self.gen_docs:
gen_doc.delete()
self.new_docs = []
self.gen_docs = []
def get_log(self):
msg_log = create_msg_from(self.log, self.new_docs)
return msg_log
"""
Source Control
"""
def get_or_create_collection(self, id_: str) -> str:
"""
generates a new id if needed
TODO: rename into get_or_create_generation_collection
TODO: have a single DB with separate collections, one for requirements, one for generation
"""
if id_ != '-1':
return id_
else:
now = datetime.datetime.now().strftime("%m%d%H%M")
letters = string.ascii_lowercase + string.digits
id_ = now + '-' + ''.join(random.choice(letters) for _ in range(10))
semantic_db.get_or_create_collection(id_)
return id_
    async def wiki_fetch(self) -> List[str]:
        """
        returns the titles of the wiki pages corresponding to the tasks described in the input text
        """
all_tasks = []
for new_doc in self.new_docs:
all_tasks += new_doc.tasks
async_tasks = [asyncio.create_task(get_wikilist(task)) for task in all_tasks]
wiki_lists = await asyncio.gather(*async_tasks)
flatten_wiki_list = list(set().union(*[set(w) for w in wiki_lists]))
return flatten_wiki_list
    async def wiki_upload_and_store(self, wiki_title: str, collection_name: str):
        """
        uploads one wiki page and stores it into the right collection
        """
        wikipage = Wiki().fetch(wiki_title)
        if not isinstance(wikipage, str):
            texts = WikiPage(wikipage.page_content).get_paragraphs()
            add_texts_to_collection(coll_name=collection_name, texts=texts, file=wiki_title, source='wiki')
        else:
            # the fetch returned a string (presumably an error message): just print it
            print(wikipage)
"""
Generate Control
"""
    async def generate_doc_from_db(self, collection_name: str, from_files: List[str]) -> List[str]:
        def query_from_task(task):
            return get_public_paragraph(task)
        async def retrieve_text_and_generate(t, collection_name: str, from_files: List[str]):
            """
            retrieves the texts from the database and generates the documents
            """
            # retrieve the texts from the database
            task_query = query_from_task(t)
            texts = query_collection(coll_name=collection_name, query=task_query, from_files=from_files)
            task_resolutions = get_private_paragraph(task=t, texts=texts)
            return task_resolutions
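        # resolve all the tasks of one document concurrently, then write the results to a generated copy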
async def real_doc_generation(new_doc):
async_task_resolutions = [asyncio.create_task(retrieve_text_and_generate(t=task, collection_name=collection_name, from_files=from_files))
for task in new_doc.tasks]
            tasks_resolutions = await asyncio.gather(*async_task_resolutions)  # to be reviewed
gen_path = f"{self.generated_docs_path}/{new_doc.name}e.docx"
gen_doc = new_doc.copy(gen_path)
gen_doc.replace_tasks(tasks_resolutions)
gen_doc.save_as_docx()
gen_paths.append(gen_doc.path)
self.gen_docs.append(gen_doc)
return gen_paths
gen_paths = []
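        # generate every document concurrently; each coroutine appends to and returns the shared gen_paths list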
gen_paths = await asyncio.gather(*[asyncio.create_task(real_doc_generation(new_doc)) for new_doc in self.new_docs])
gen_paths = [path for sublist in gen_paths for path in sublist]
gen_paths = list(set(gen_paths))
return gen_paths
"""
Requirements
"""
def clear_input_csv(self):
self.input_csv = ""
[os.remove(f"{self.excel_doc_path}/{doc}") for doc in os.listdir(self.excel_doc_path)]
def set_input_csv(self, csv_path: str):
"""
TODO: rename to set_requirements_file
"""
self.input_csv = csv_path
    def create_collection(self, docs: List[Doc]):
"""
TODO: rename to create_requirements_collection
TODO: merge with semantic tool to have only one DB Object
"""
coll_name = "collection_for_docs"
collection = self.client_db.get_or_create_collection(coll_name)
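        # only fill the collection the first time it is created; afterwards it is reused as-is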
if collection.count() == 0:
for doc in docs:
self.fill_collection(doc, collection)
self.retriever.collection = collection
    def fill_collection(self, doc: Doc, collection):
        """
        fills the collection with the blocks of the document
        """
        # instantiating a Retriever adds the document's blocks to the collection
        Retriever(doc=doc, collection=collection)
@staticmethod
    def _select_best_sources(sources: List[Block], delta_1_2=0.15, delta_1_n=0.3, absolute=1.2, alpha=0.9, max_blocks=3) -> List[Block]:
"""
Select the best sources: not far from the very best, not far from the last selected, and not too bad per se
"""
best_sources = []
for idx, s in enumerate(sources):
if idx == 0 \
or (s.distance - sources[idx - 1].distance < delta_1_2
and s.distance - sources[0].distance < delta_1_n) \
or s.distance < absolute:
best_sources.append(s)
delta_1_2 *= alpha
delta_1_n *= alpha
absolute *= alpha
else:
break
best_sources = sorted(best_sources, key=lambda x: x.distance)[:max_blocks]
return best_sources
    def generate_response_to_requirements(self):
        dict_of_excel_content = self.get_requirements_from_csv()
        for exigence in dict_of_excel_content:
            blocks_sources = self.retriever.similarity_search(queries=exigence["Exigence"])
            best_sources = self._select_best_sources(blocks_sources)
            sources_contents = [f"Paragraph title : {s.title}\n-----\n{s.content}" if s.title
                                else f"Paragraph {s.index}\n-----\n{s.content}"
                                for s in best_sources]
            context = '\n'.join(sources_contents)
            # drop the last sources one by one until the context fits in 15000 characters
            i = 1
            while len(context) > 15000 and i < len(sources_contents):
                context = "\n".join(sources_contents[:-i])
                i += 1
            reponse_exigence = generate_response_to_exigence(exigence=exigence["Exigence"],
                                                             titre_exigence=exigence["Titre"],
                                                             content=context)
            exigence["Conformité"] = reponse_exigence
            exigence["Document"] = best_sources[0].doc
            exigence["Paragraphes"] = "; ".join([block.index for block in best_sources])
excel_name = self.input_csv
if '/' in excel_name:
excel_name = excel_name.split('/')[-1]
elif '\\' in excel_name:
excel_name = excel_name.split('\\')[-1]
df = pd.DataFrame(data=dict_of_excel_content)
df.to_excel(f"{self.excel_doc_path}/{excel_name}", index=False)
return f"{self.excel_doc_path}/{excel_name}"
def get_requirements_from_csv(self):
excel_content = excel_to_dict(self.input_csv)
return excel_content