adrien.aribaut-gaudin
fix: gitignore for the database folder + prompt for requirements + 3 blocks max for best_sources
8e58322
import asyncio | |
import os | |
import shutil | |
import json | |
from typing import Dict | |
import random | |
import datetime | |
import string | |
import docx | |
import pandas as pd | |
from src.domain.block import Block | |
from src.tools.doc_tools import get_title | |
from src.domain.doc import Doc | |
from src.domain.wikidoc import WikiPage | |
from src.view.log_msg import create_msg_from | |
import src.tools.semantic_db as semantic_db | |
from src.tools.wiki import Wiki | |
from src.llm.llm_tools import generate_response_to_exigence | |
from src.llm.llm_tools import get_wikilist, get_public_paragraph, get_private_paragraph | |
from src.tools.semantic_db import add_texts_to_collection, query_collection | |
from src.tools.excel_tools import excel_to_dict | |
import gradio as gr | |
from src.retriever.retriever import Retriever | |
class Controller:

    def __init__(self, config: Dict, client_db, retriever):
        """
        Orchestrates document templating, generation and requirements handling.

        :param config: settings dict (paths, template list, default template index)
        :param client_db: database client used to store the requirements collection
        :param retriever: similarity-search helper fed by the requirements collection
        """
        # Working directories read from the configuration.
        self.templates_path = config['templates_path']
        self.generated_docs_path = config['generated_docs_path']
        self.styled_docs_path = config['styled_docs_path']
        self.excel_doc_path = config['excel_doc_path']
        # Documents loaded / generated during this session.
        self.new_docs = []
        self.gen_docs = []
        self.input_csv = ""
        # Template selected at startup; kept so set_template("") can restore it.
        default_name = config['templates'][config['default_template_index']]
        self.default_template = Doc(f"{config['templates_path']}/{default_name}")
        self.template = self.default_template
        # Accumulated per-document log messages and style differences.
        self.log = []
        self.differences = []
        self.list_differences = []
        self.client_db = client_db
        self.retriever = retriever
def copy_docs(self, temp_docs: []): | |
""" | |
Initial copy of the incoming document | |
+ | |
create collection for requirments retrieval | |
+ | |
Initiate paths | |
TODO: Rename or refactor the function -> 1 mission / function | |
TODO: To be tested on several documents | |
TODO: Rename create_collection in create_requirement_collection | |
""" | |
doc_names = [doc.name for doc in temp_docs] | |
for i in range(len(doc_names)): | |
if '/' in doc_names[i]: | |
doc_names[i] = doc_names[i].split('/')[-1] | |
elif '\\' in doc_names[i]: | |
doc_names[i] = doc_names[i].split('\\')[-1] | |
doc_names[i] = doc_names[i].split('.')[0] | |
docs = [Doc(path=doc.name) for doc in temp_docs] | |
self.create_collection(docs) | |
style_paths = [f"{self.generated_docs_path}/{dn}_.docx" for dn in doc_names] | |
gen_paths = [f"{self.generated_docs_path}/{dn}_e.docx" for dn in doc_names] | |
for doc, style_path, gen_path in zip(docs, style_paths, gen_paths): | |
new_doc = doc.copy(style_path) | |
self.new_docs.append(new_doc) | |
def clear_docs(self): | |
for new_doc in self.new_docs: | |
if os.path.exists(new_doc.path): | |
new_doc.clear() | |
for gen_doc in self.gen_docs: | |
if os.path.exists(gen_doc.path): | |
gen_doc.clear() | |
self.new_docs = [] | |
self.gen_docs = [] | |
self.log = [] | |
path_to_clear = os.path.abspath(self.generated_docs_path) | |
second_path_to_clear = os.path.abspath(self.excel_doc_path) | |
[os.remove(f"{path_to_clear}/{doc}") for doc in os.listdir(path_to_clear)] | |
[os.remove(f"{second_path_to_clear}/{doc}") for doc in os.listdir(second_path_to_clear)] | |
def set_template(self, template_name: str = ""): | |
if not template_name: | |
self.template = self.default_template | |
else: | |
template_path = f"{self.templates_path}/{template_name}" | |
self.template = Doc(template_path) | |
def add_template(self, template_path: str): | |
""" | |
TODO: message to be but in config | |
""" | |
if not template_path: | |
return | |
elif not template_path.name.endswith(".docx"): | |
gr.Warning("Seuls les fichiers .docx sont acceptés") | |
return | |
doc = docx.Document(template_path.name) | |
doc.save(self.templates_path + '/' + get_title(template_path.name)) | |
def delete_curr_template(self, template_name: str): | |
if not template_name: | |
return | |
os.remove(f"{self.templates_path}/{template_name}") | |
def retrieve_number_of_misapplied_styles(self): | |
""" | |
not used: buggy !! | |
""" | |
res = {} | |
for new_doc in self.new_docs: | |
res[new_doc] = new_doc.retrieve_number_of_misapplied_styles() | |
return res | |
def get_difference_with_template(self): | |
self.differences = [] | |
for new_doc in self.new_docs: | |
diff_styles = new_doc.get_different_styles_with_template(template=self.template) | |
diff_dicts = [{'doc': new_doc, 'style': s} for s in diff_styles] | |
self.differences += diff_dicts | |
template_styles = self.template.xdoc.styles | |
template_styles = [style for style in template_styles if style.name in self.template.styles.names] | |
return self.differences, template_styles | |
def get_list_styles(self): | |
self.list_differences = [] | |
for new_doc in self.new_docs: | |
list_styles = new_doc.get_list_styles() | |
all_lists_styles = [{'doc': new_doc, 'list_style': s} for s in list_styles] | |
self.list_differences += all_lists_styles | |
return self.list_differences | |
def map_style(self, this_style_index: int, template_style_name: str): | |
""" | |
maps a style from 'this' document into a style from the template | |
""" | |
#dont make any change if the style is already the same | |
diff_dict = self.differences[this_style_index] | |
doc = diff_dict['doc'] | |
this_style_name = diff_dict['style'] | |
log = doc.copy_one_style(this_style_name, template_style_name, self.template) | |
if log: | |
self.log.append({doc.name: log}) | |
def update_list_style(self, this_style_index: int, template_style_name: str): | |
""" | |
maps a style from 'this' document into a style from the template | |
""" | |
#dont make any change if the style is already the same | |
diff_dict = self.list_differences[this_style_index] | |
doc = diff_dict['doc'] | |
this_style_name = diff_dict['list_style'] | |
log = doc.change_bullet_style(this_style_name, template_style_name, self.template) | |
if log: | |
self.log.append({doc.name: log}) | |
def update_style(self,index,style_to_modify): | |
return self.map_style(index, style_to_modify) if style_to_modify else None | |
def apply_template(self, options_list): | |
for new_doc in self.new_docs: | |
log = new_doc.apply_template(template=self.template, options_list=options_list) | |
if log: | |
self.log.append({new_doc.name: log}) | |
def reset(self): | |
for new_doc in self.new_docs: | |
new_doc.delete() | |
for gen_doc in self.gen_docs: | |
gen_doc.delete() | |
self.new_docs = [] | |
self.gen_docs = [] | |
def get_log(self): | |
msg_log = create_msg_from(self.log, self.new_docs) | |
return msg_log | |
""" | |
Source Control | |
""" | |
def get_or_create_collection(self, id_: str) -> str: | |
""" | |
generates a new id if needed | |
TODO: rename into get_or_create_generation_collection | |
TODO: have a single DB with separate collections, one for requirements, one for generation | |
""" | |
if id_ != '-1': | |
return id_ | |
else: | |
now = datetime.datetime.now().strftime("%m%d%H%M") | |
letters = string.ascii_lowercase + string.digits | |
id_ = now + '-' + ''.join(random.choice(letters) for _ in range(10)) | |
semantic_db.get_or_create_collection(id_) | |
return id_ | |
async def wiki_fetch(self) -> [str]: | |
""" | |
returns the title of the wikipages corresponding to the tasks described in the input text | |
""" | |
all_tasks = [] | |
for new_doc in self.new_docs: | |
all_tasks += new_doc.tasks | |
async_tasks = [asyncio.create_task(get_wikilist(task)) for task in all_tasks] | |
wiki_lists = await asyncio.gather(*async_tasks) | |
flatten_wiki_list = list(set().union(*[set(w) for w in wiki_lists])) | |
return flatten_wiki_list | |
async def wiki_upload_and_store(self, wiki_title: str, collection_name: str): | |
""" | |
uploads one wikipage and stores them into the right collection | |
""" | |
wikipage = Wiki().fetch(wiki_title) | |
wiki_title = wiki_title | |
if type(wikipage) != str: | |
texts = WikiPage(wikipage.page_content).get_paragraphs() | |
add_texts_to_collection(coll_name=collection_name, texts=texts, file=wiki_title, source='wiki') | |
else: | |
print(wikipage) | |
""" | |
Generate Control | |
""" | |
async def generate_doc_from_db(self, collection_name: str, from_files: [str]) -> [str]: | |
def query_from_task(task): | |
return get_public_paragraph(task) | |
async def retrieve_text_and_generate(t, collection_name: str, from_files: [str]): | |
""" | |
retreives the texts from the database and generates the documents | |
""" | |
# retreive the texts from the database | |
task_query = query_from_task(t) | |
texts = query_collection(coll_name=collection_name, query=task_query, from_files=from_files) | |
task_resolutions = get_private_paragraph(task=t, texts=texts) | |
return task_resolutions | |
async def real_doc_generation(new_doc): | |
async_task_resolutions = [asyncio.create_task(retrieve_text_and_generate(t=task, collection_name=collection_name, from_files=from_files)) | |
for task in new_doc.tasks] | |
tasks_resolutions = await asyncio.gather(*async_task_resolutions) #A VOIR | |
gen_path = f"{self.generated_docs_path}/{new_doc.name}e.docx" | |
gen_doc = new_doc.copy(gen_path) | |
gen_doc.replace_tasks(tasks_resolutions) | |
gen_doc.save_as_docx() | |
gen_paths.append(gen_doc.path) | |
self.gen_docs.append(gen_doc) | |
return gen_paths | |
gen_paths = [] | |
gen_paths = await asyncio.gather(*[asyncio.create_task(real_doc_generation(new_doc)) for new_doc in self.new_docs]) | |
gen_paths = [path for sublist in gen_paths for path in sublist] | |
gen_paths = list(set(gen_paths)) | |
return gen_paths | |
""" | |
Requirements | |
""" | |
def clear_input_csv(self): | |
self.input_csv = "" | |
[os.remove(f"{self.excel_doc_path}/{doc}") for doc in os.listdir(self.excel_doc_path)] | |
def set_input_csv(self, csv_path: str): | |
""" | |
TODO: rename to set_requirements_file | |
""" | |
self.input_csv = csv_path | |
def create_collection(self, docs: [Doc]): | |
""" | |
TODO: rename to create_requirements_collection | |
TODO: merge with semantic tool to have only one DB Object | |
""" | |
coll_name = "collection_for_docs" | |
collection = self.client_db.get_or_create_collection(coll_name) | |
if collection.count() == 0: | |
for doc in docs: | |
self.fill_collection(doc, collection) | |
self.retriever.collection = collection | |
    def fill_collection(self, doc: Doc, collection: str):
        """
        fills the collection with the blocks of the documents

        Instantiating Retriever with a doc and a collection indexes the doc's
        blocks into the collection as a side effect; the instance is discarded.
        NOTE(review): `collection` is the object returned by the DB client, not
        a str — the annotation looks wrong; confirm before changing it.
        """
        Retriever(doc=doc, collection=collection)
def _select_best_sources(sources: [Block], delta_1_2=0.15, delta_1_n=0.3, absolute=1.2, alpha=0.9, max_blocks=3) -> [Block]: | |
""" | |
Select the best sources: not far from the very best, not far from the last selected, and not too bad per se | |
""" | |
best_sources = [] | |
for idx, s in enumerate(sources): | |
if idx == 0 \ | |
or (s.distance - sources[idx - 1].distance < delta_1_2 | |
and s.distance - sources[0].distance < delta_1_n) \ | |
or s.distance < absolute: | |
best_sources.append(s) | |
delta_1_2 *= alpha | |
delta_1_n *= alpha | |
absolute *= alpha | |
else: | |
break | |
best_sources = sorted(best_sources, key=lambda x: x.distance)[:max_blocks] | |
return best_sources | |
def generate_response_to_requirements(self): | |
dict_of_excel_content = self.get_requirements_from_csv() | |
for exigence in dict_of_excel_content: | |
blocks_sources = self.retriever.similarity_search(queries = exigence["Exigence"]) | |
best_sources = self._select_best_sources(blocks_sources) | |
sources_contents = [f"Paragraph title : {s.title}\n-----\n{s.content}" if s.title else f"Paragraph {s.index}\n-----\n{s.content}" for s in best_sources] | |
context = '\n'.join(sources_contents) | |
i = 1 | |
while (len(context) > 15000) and i < len(sources_contents): | |
context = "\n".join(sources_contents[:-i]) | |
i += 1 | |
reponse_exigence = generate_response_to_exigence(exigence = exigence["Exigence"], titre_exigence = exigence["Titre"], content = context) | |
dict_of_excel_content[dict_of_excel_content.index(exigence)]["Conformité"] = reponse_exigence | |
dict_of_excel_content[dict_of_excel_content.index(exigence)]["Document"] = best_sources[0].doc | |
dict_of_excel_content[dict_of_excel_content.index(exigence)]["Paragraphes"] = "; ".join([block.index for block in best_sources]) | |
excel_name = self.input_csv | |
if '/' in excel_name: | |
excel_name = excel_name.split('/')[-1] | |
elif '\\' in excel_name: | |
excel_name = excel_name.split('\\')[-1] | |
df = pd.DataFrame(data=dict_of_excel_content) | |
df.to_excel(f"{self.excel_doc_path}/{excel_name}", index=False) | |
return f"{self.excel_doc_path}/{excel_name}" | |
def get_requirements_from_csv(self): | |
excel_content = excel_to_dict(self.input_csv) | |
return excel_content |